basis functions

This commit is contained in:
2015-08-29 22:09:20 +02:00
parent 1f431ac27f
commit 87996d1a86
4 changed files with 120 additions and 0 deletions

View File

@@ -0,0 +1,18 @@
#ifndef _BASIS_FUN_H_
#define _BASIS_FUN_H_
#include <math.h>
namespace NeuralNetwork
{
namespace BasisFunction
{
class BasisFunction
{
public:
virtual ~BasisFunction() {}
virtual float operator()(const size_t &inputSize, const float* weights, const float* input)=0;
};
}
}
#endif

View File

@@ -0,0 +1,63 @@
#ifndef __BASIS_FEEDFORWARD_H_
#define __BASIS_FEEDFORWARD_H_
#include "./StreamingBasisFunction.h"
#include <mmintrin.h>
#include <xmmintrin.h>
#include <emmintrin.h>
#include <xmmintrin.h>
#include <pmmintrin.h>
namespace NeuralNetwork
{
namespace BasisFunction
{
class FeedForward: public StreamingBasisFunction
{
public:
FeedForward() {}
inline virtual __m128 operator()(const size_t& inputSize, const float* weights, const float* input, const size_t& alignedPrev)
{
__m128 partialSolution= _mm_setzero_ps();
__m128 w=_mm_setzero_ps();
__m128 sols;
for(register size_t k=alignedPrev;k<inputSize;k++)
{
w = _mm_load_ss(weights+k);
sols = _mm_load_ss(input+k);
w=_mm_mul_ps(w,sols);
partialSolution=_mm_add_ps(partialSolution,w);
}
for(register size_t k=0;k<alignedPrev;k+=sizeof(float)) // TODO ??? sizeof(float)
{
w = _mm_load_ps(weights+k);
sols = _mm_load_ps(input+k);
w=_mm_mul_ps(w,sols);
partialSolution=_mm_add_ps(partialSolution,w);
}
#ifdef USE_SSE2 //pre-SSE3 solution
partialSolution= _mm_add_ps(_mm_movehl_ps(partialSolution, partialSolution), partialSolution);
partialSolution=_mm_add_ss(partialSolution, _mm_shuffle_ps(partialSolution,partialSolution, 1));
#else
partialSolution = _mm_hadd_ps(partialSolution, partialSolution);
partialSolution = _mm_hadd_ps(partialSolution, partialSolution);
#endif
return partialSolution;
}
inline virtual float operator()(const size_t &inputSize, const float* weights, const float* input)
{
register float tmp = 0;
for(register size_t k=0;k<inputSize;k++)
{
tmp+=input[k]*weights[k];
}
return tmp;
}
};
}
}
#endif

View File

@@ -0,0 +1,17 @@
#ifndef __BASIS_RADIAL_H_
#define __BASIS_RADIAL_H_
#include "./BasisFunction.h"
namespace NeuralNetwork
{
namespace BasisFunction
{
class Radial: public BasisFunction
{
public:
Radial() {}
};
}
}
#endif

View File

@@ -0,0 +1,22 @@
#ifndef __STREAMINGBASIS_FUN_H_
#define __STREAMINGBASIS_FUN_H_
#include <xmmintrin.h>
#include "../../sse_mathfun.h"
#include "./BasisFunction.h"
namespace NeuralNetwork
{
namespace BasisFunction
{
class StreamingBasisFunction : public BasisFunction
{
public:
virtual float operator()(const size_t &inputSize, const float* weights, const float* input) = 0;
virtual __m128 operator()(const size_t& inputSize, const float* weights, const float* input, const size_t& alignedPrev) =0;
};
}
}
#endif