added AVX for linear basis function
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@@ -1,9 +1,3 @@
|
|||||||
NN.kdev4
|
|
||||||
.kdev4
|
|
||||||
*.o
|
|
||||||
*.a
|
|
||||||
*.so
|
|
||||||
*.nm
|
|
||||||
/doc/html/*
|
/doc/html/*
|
||||||
!/doc/html/doxy-boot.js
|
!/doc/html/doxy-boot.js
|
||||||
!/doc/html/jquery.powertip.min.js
|
!/doc/html/jquery.powertip.min.js
|
||||||
|
|||||||
@@ -2,22 +2,27 @@ cmake_minimum_required(VERSION 3.2)
|
|||||||
project(NeuralNetwork CXX)
|
project(NeuralNetwork CXX)
|
||||||
|
|
||||||
OPTION(BUILD_SHARED_LIBS "Build also shared library." ON)
|
OPTION(BUILD_SHARED_LIBS "Build also shared library." ON)
|
||||||
|
OPTION(USE_AVX "IF avx should be used." ON)
|
||||||
OPTION(USE_SSE "IF sse should be used." ON)
|
OPTION(USE_SSE "IF sse should be used." ON)
|
||||||
OPTION(USE_SSE2 "IF only sse2 should be used." OFF)
|
OPTION(USE_SSE2 "IF only sse2 should be used." OFF)
|
||||||
|
|
||||||
OPTION(ENABLE_TESTS "enables tests" ON)
|
OPTION(ENABLE_TESTS "enables tests" ON)
|
||||||
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Weffc++ -Wshadow -Wstrict-aliasing -ansi -Woverloaded-virtual -Wdelete-non-virtual-dtor -Wno-unused-function")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Weffc++ -Wshadow -Wstrict-aliasing -ansi -Woverloaded-virtual -Wdelete-non-virtual-dtor -Wno-unused-function")
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14")
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g")
|
||||||
#set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -pthread")
|
#set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -pthread")
|
||||||
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native -O3")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native -O3")
|
||||||
|
|
||||||
if(USE_SSE)
|
if(USE_AVX)
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_AVX")
|
||||||
|
elseif(USE_SSE)
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.2 -DUSE_SSE")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.2 -DUSE_SSE")
|
||||||
if(USE_SSE2)
|
if(USE_SSE2)
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_SSE2")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_SSE2")
|
||||||
endif(USE_SSE2)
|
endif(USE_SSE2)
|
||||||
endif(USE_SSE)
|
endif(USE_AVX)
|
||||||
|
|
||||||
include_directories(./include/)
|
include_directories(./include/)
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
#include <emmintrin.h>
|
#include <emmintrin.h>
|
||||||
#include <xmmintrin.h>
|
#include <xmmintrin.h>
|
||||||
#include <pmmintrin.h>
|
#include <pmmintrin.h>
|
||||||
|
#include <immintrin.h>
|
||||||
|
|
||||||
#include "./StreamingBasisFunction.h"
|
#include "./StreamingBasisFunction.h"
|
||||||
|
|
||||||
@@ -18,8 +19,37 @@ namespace BasisFunction {
|
|||||||
Linear() {}
|
Linear() {}
|
||||||
|
|
||||||
inline virtual float operator()(const std::vector<float>& weights, const std::vector<float>& input) const override {
|
inline virtual float operator()(const std::vector<float>& weights, const std::vector<float>& input) const override {
|
||||||
|
#ifdef USE_AVX
|
||||||
|
//TODO: check sizes!!!
|
||||||
|
std::size_t inputSize=input.size();
|
||||||
|
size_t alignedPrev=inputSize-inputSize%8;
|
||||||
|
|
||||||
#ifdef USE_SSE
|
const float* weightsData=weights.data();
|
||||||
|
const float* inputData=input.data();
|
||||||
|
|
||||||
|
union {
|
||||||
|
__m256 avx;
|
||||||
|
float f[8];
|
||||||
|
} partialSolution;
|
||||||
|
|
||||||
|
partialSolution.avx=_mm256_setzero_ps();
|
||||||
|
|
||||||
|
for(size_t k=0;k<alignedPrev;k+=8) {
|
||||||
|
//TODO: assignment!! -- possible speedup
|
||||||
|
partialSolution.avx=_mm256_add_ps(partialSolution.avx,_mm256_mul_ps(_mm256_loadu_ps(weightsData+k),_mm256_loadu_ps(inputData+k)));
|
||||||
|
}
|
||||||
|
|
||||||
|
for(size_t k=alignedPrev;k<inputSize;k++) {
|
||||||
|
partialSolution.avx=_mm256_add_ps(partialSolution.avx,_mm256_mul_ps(_mm256_set_ps(weightsData[k],0,0,0,0,0,0,0),_mm256_set_ps(inputData[k],0,0,0,0,0,0,0)));
|
||||||
|
}
|
||||||
|
|
||||||
|
partialSolution.avx = _mm256_add_ps(partialSolution.avx, _mm256_permute2f128_ps(partialSolution.avx , partialSolution.avx , 1));
|
||||||
|
partialSolution.avx = _mm256_hadd_ps(partialSolution.avx, partialSolution.avx);
|
||||||
|
partialSolution.avx = _mm256_hadd_ps(partialSolution.avx, partialSolution.avx);
|
||||||
|
|
||||||
|
return partialSolution.f[0];
|
||||||
|
#else
|
||||||
|
#ifdef USE_SSE
|
||||||
size_t inputSize=input.size();
|
size_t inputSize=input.size();
|
||||||
size_t alignedPrev=inputSize-inputSize%4;
|
size_t alignedPrev=inputSize-inputSize%4;
|
||||||
|
|
||||||
@@ -37,15 +67,15 @@ namespace BasisFunction {
|
|||||||
partialSolution.sse=_mm_add_ps(partialSolution.sse,_mm_mul_ps(_mm_load_ss(weightsData+k),_mm_load_ss(inputData+k)));
|
partialSolution.sse=_mm_add_ps(partialSolution.sse,_mm_mul_ps(_mm_load_ss(weightsData+k),_mm_load_ss(inputData+k)));
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef USE_SSE2 //pre-SSE3 solution
|
#ifdef USE_SSE2 //pre-SSE3 solution
|
||||||
partialSolution.sse= _mm_add_ps(_mm_movehl_ps(partialSolution.sse, partialSolution.sse), partialSolution.sse);
|
partialSolution.sse= _mm_add_ps(_mm_movehl_ps(partialSolution.sse, partialSolution.sse), partialSolution.sse);
|
||||||
partialSolution.sse=_mm_add_ss(partialSolution.sse, _mm_shuffle_ps(partialSolution.sse,partialSolution.sse, 1));
|
partialSolution.sse=_mm_add_ss(partialSolution.sse, _mm_shuffle_ps(partialSolution.sse,partialSolution.sse, 1));
|
||||||
#else
|
#else
|
||||||
partialSolution.sse = _mm_hadd_ps(partialSolution.sse, partialSolution.sse);
|
partialSolution.sse = _mm_hadd_ps(partialSolution.sse, partialSolution.sse);
|
||||||
partialSolution.sse = _mm_hadd_ps(partialSolution.sse, partialSolution.sse);
|
partialSolution.sse = _mm_hadd_ps(partialSolution.sse, partialSolution.sse);
|
||||||
#endif
|
#endif
|
||||||
return partialSolution.f[0];
|
return partialSolution.f[0];
|
||||||
#else
|
#else
|
||||||
|
|
||||||
register float tmp = 0;
|
register float tmp = 0;
|
||||||
size_t inputSize=input.size();
|
size_t inputSize=input.size();
|
||||||
@@ -53,6 +83,7 @@ namespace BasisFunction {
|
|||||||
tmp+=input[k]*weights[k];
|
tmp+=input[k]*weights[k];
|
||||||
}
|
}
|
||||||
return tmp;
|
return tmp;
|
||||||
|
#endif
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user