Added AVX support for the linear basis function

This commit is contained in:
2016-02-16 22:07:34 +01:00
parent e5dddc926a
commit 435524fb6b
3 changed files with 43 additions and 13 deletions

6
.gitignore vendored
View File

@@ -1,9 +1,3 @@
NN.kdev4
.kdev4
*.o
*.a
*.so
*.nm
/doc/html/* /doc/html/*
!/doc/html/doxy-boot.js !/doc/html/doxy-boot.js
!/doc/html/jquery.powertip.min.js !/doc/html/jquery.powertip.min.js

View File

@@ -2,22 +2,27 @@ cmake_minimum_required(VERSION 3.2)
project(NeuralNetwork CXX) project(NeuralNetwork CXX)
OPTION(BUILD_SHARED_LIBS "Build also shared library." ON) OPTION(BUILD_SHARED_LIBS "Build also shared library." ON)
OPTION(USE_AVX "IF avx should be used." ON)
OPTION(USE_SSE "IF sse should be used." ON) OPTION(USE_SSE "IF sse should be used." ON)
OPTION(USE_SSE2 "IF only sse2 should be used." OFF) OPTION(USE_SSE2 "IF only sse2 should be used." OFF)
OPTION(ENABLE_TESTS "enables tests" ON) OPTION(ENABLE_TESTS "enables tests" ON)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Weffc++ -Wshadow -Wstrict-aliasing -ansi -Woverloaded-virtual -Wdelete-non-virtual-dtor -Wno-unused-function") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Weffc++ -Wshadow -Wstrict-aliasing -ansi -Woverloaded-virtual -Wdelete-non-virtual-dtor -Wno-unused-function")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g")
#set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -pthread") #set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -pthread")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native -O3") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native -O3")
if(USE_SSE) if(USE_AVX)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_AVX")
elseif(USE_SSE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.2 -DUSE_SSE") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.2 -DUSE_SSE")
if(USE_SSE2) if(USE_SSE2)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_SSE2") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_SSE2")
endif(USE_SSE2) endif(USE_SSE2)
endif(USE_SSE) endif(USE_AVX)
include_directories(./include/) include_directories(./include/)

View File

@@ -5,6 +5,7 @@
#include <emmintrin.h> #include <emmintrin.h>
#include <xmmintrin.h> #include <xmmintrin.h>
#include <pmmintrin.h> #include <pmmintrin.h>
#include <immintrin.h>
#include "./StreamingBasisFunction.h" #include "./StreamingBasisFunction.h"
@@ -18,8 +19,37 @@ namespace BasisFunction {
Linear() {} Linear() {}
inline virtual float operator()(const std::vector<float>& weights, const std::vector<float>& input) const override { inline virtual float operator()(const std::vector<float>& weights, const std::vector<float>& input) const override {
#ifdef USE_AVX
//TODO: check sizes!!!
std::size_t inputSize=input.size();
size_t alignedPrev=inputSize-inputSize%8;
#ifdef USE_SSE const float* weightsData=weights.data();
const float* inputData=input.data();
union {
__m256 avx;
float f[8];
} partialSolution;
partialSolution.avx=_mm256_setzero_ps();
for(size_t k=0;k<alignedPrev;k+=8) {
//TODO: asignement!! -- possible speedup
partialSolution.avx=_mm256_add_ps(partialSolution.avx,_mm256_mul_ps(_mm256_loadu_ps(weightsData+k),_mm256_loadu_ps(inputData+k)));
}
for(size_t k=alignedPrev;k<inputSize;k++) {
partialSolution.avx=_mm256_add_ps(partialSolution.avx,_mm256_mul_ps(_mm256_set_ps(weightsData[k],0,0,0,0,0,0,0),_mm256_set_ps(inputData[k],0,0,0,0,0,0,0)));
}
partialSolution.avx = _mm256_add_ps(partialSolution.avx, _mm256_permute2f128_ps(partialSolution.avx , partialSolution.avx , 1));
partialSolution.avx = _mm256_hadd_ps(partialSolution.avx, partialSolution.avx);
partialSolution.avx = _mm256_hadd_ps(partialSolution.avx, partialSolution.avx);
return partialSolution.f[0];
#else
#ifdef USE_SSE
size_t inputSize=input.size(); size_t inputSize=input.size();
size_t alignedPrev=inputSize-inputSize%4; size_t alignedPrev=inputSize-inputSize%4;
@@ -45,7 +75,7 @@ namespace BasisFunction {
partialSolution.sse = _mm_hadd_ps(partialSolution.sse, partialSolution.sse); partialSolution.sse = _mm_hadd_ps(partialSolution.sse, partialSolution.sse);
#endif #endif
return partialSolution.f[0]; return partialSolution.f[0];
#else #else
register float tmp = 0; register float tmp = 0;
size_t inputSize=input.size(); size_t inputSize=input.size();
@@ -53,6 +83,7 @@ namespace BasisFunction {
tmp+=input[k]*weights[k]; tmp+=input[k]*weights[k];
} }
return tmp; return tmp;
#endif
#endif #endif
} }