basis functions
This commit is contained in:
18
src/NeuralNetwork/BasisFunction/BasisFunction.h
Normal file
18
src/NeuralNetwork/BasisFunction/BasisFunction.h
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
#ifndef _BASIS_FUN_H_
|
||||||
|
#define _BASIS_FUN_H_
|
||||||
|
|
||||||
|
#include <math.h>
|
||||||
|
|
||||||
|
namespace NeuralNetwork
|
||||||
|
{
|
||||||
|
namespace BasisFunction
|
||||||
|
{
|
||||||
|
class BasisFunction
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
virtual ~BasisFunction() {}
|
||||||
|
virtual float operator()(const size_t &inputSize, const float* weights, const float* input)=0;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
63
src/NeuralNetwork/BasisFunction/FeedForward.h
Normal file
63
src/NeuralNetwork/BasisFunction/FeedForward.h
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
#ifndef __BASIS_FEEDFORWARD_H_
|
||||||
|
#define __BASIS_FEEDFORWARD_H_
|
||||||
|
|
||||||
|
#include "./StreamingBasisFunction.h"
|
||||||
|
|
||||||
|
#include <mmintrin.h>
|
||||||
|
#include <xmmintrin.h>
|
||||||
|
#include <emmintrin.h>
|
||||||
|
#include <xmmintrin.h>
|
||||||
|
#include <pmmintrin.h>
|
||||||
|
|
||||||
|
namespace NeuralNetwork
|
||||||
|
{
|
||||||
|
namespace BasisFunction
|
||||||
|
{
|
||||||
|
class FeedForward: public StreamingBasisFunction
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
FeedForward() {}
|
||||||
|
|
||||||
|
inline virtual __m128 operator()(const size_t& inputSize, const float* weights, const float* input, const size_t& alignedPrev)
|
||||||
|
{
|
||||||
|
__m128 partialSolution= _mm_setzero_ps();
|
||||||
|
__m128 w=_mm_setzero_ps();
|
||||||
|
__m128 sols;
|
||||||
|
for(register size_t k=alignedPrev;k<inputSize;k++)
|
||||||
|
{
|
||||||
|
w = _mm_load_ss(weights+k);
|
||||||
|
sols = _mm_load_ss(input+k);
|
||||||
|
w=_mm_mul_ps(w,sols);
|
||||||
|
partialSolution=_mm_add_ps(partialSolution,w);
|
||||||
|
}
|
||||||
|
for(register size_t k=0;k<alignedPrev;k+=sizeof(float)) // TODO ??? sizeof(float)
|
||||||
|
{
|
||||||
|
w = _mm_load_ps(weights+k);
|
||||||
|
sols = _mm_load_ps(input+k);
|
||||||
|
w=_mm_mul_ps(w,sols);
|
||||||
|
partialSolution=_mm_add_ps(partialSolution,w);
|
||||||
|
}
|
||||||
|
#ifdef USE_SSE2 //pre-SSE3 solution
|
||||||
|
partialSolution= _mm_add_ps(_mm_movehl_ps(partialSolution, partialSolution), partialSolution);
|
||||||
|
partialSolution=_mm_add_ss(partialSolution, _mm_shuffle_ps(partialSolution,partialSolution, 1));
|
||||||
|
#else
|
||||||
|
partialSolution = _mm_hadd_ps(partialSolution, partialSolution);
|
||||||
|
partialSolution = _mm_hadd_ps(partialSolution, partialSolution);
|
||||||
|
#endif
|
||||||
|
return partialSolution;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline virtual float operator()(const size_t &inputSize, const float* weights, const float* input)
|
||||||
|
{
|
||||||
|
register float tmp = 0;
|
||||||
|
for(register size_t k=0;k<inputSize;k++)
|
||||||
|
{
|
||||||
|
tmp+=input[k]*weights[k];
|
||||||
|
}
|
||||||
|
return tmp;
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
17
src/NeuralNetwork/BasisFunction/Radial.h
Normal file
17
src/NeuralNetwork/BasisFunction/Radial.h
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
#ifndef __BASIS_RADIAL_H_
|
||||||
|
#define __BASIS_RADIAL_H_
|
||||||
|
|
||||||
|
#include "./BasisFunction.h"
|
||||||
|
|
||||||
|
namespace NeuralNetwork
|
||||||
|
{
|
||||||
|
namespace BasisFunction
|
||||||
|
{
|
||||||
|
class Radial: public BasisFunction
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
Radial() {}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
22
src/NeuralNetwork/BasisFunction/StreamingBasisFunction.h
Normal file
22
src/NeuralNetwork/BasisFunction/StreamingBasisFunction.h
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
#ifndef __STREAMINGBASIS_FUN_H_
|
||||||
|
#define __STREAMINGBASIS_FUN_H_
|
||||||
|
|
||||||
|
#include <xmmintrin.h>
|
||||||
|
|
||||||
|
#include "../../sse_mathfun.h"
|
||||||
|
|
||||||
|
#include "./BasisFunction.h"
|
||||||
|
|
||||||
|
namespace NeuralNetwork
|
||||||
|
{
|
||||||
|
namespace BasisFunction
|
||||||
|
{
|
||||||
|
class StreamingBasisFunction : public BasisFunction
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
virtual float operator()(const size_t &inputSize, const float* weights, const float* input) = 0;
|
||||||
|
virtual __m128 operator()(const size_t& inputSize, const float* weights, const float* input, const size_t& alignedPrev) =0;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
Reference in New Issue
Block a user