light optimizations in SSE FFQ NN

This commit is contained in:
2014-11-15 11:21:07 +01:00
parent 9a2d7b85b1
commit 207e141cca
3 changed files with 22 additions and 19 deletions

View File

@@ -63,15 +63,25 @@ void FeedForwardNetworkQuick::solvePart(float *newSolution, register size_t begi
{
if(prevSize >8)
{
__m128 partialSolution;
__m128 partialSolution2;
__m128 w;
__m128 sols;
__m128 w2;
__m128 sols2;
__m128 temporaryConst1=_mm_set1_ps(1.0);
__m128 temporaryConstLambda=_mm_set1_ps(-lambda);
register size_t alignedPrev=prevSize>8?(prevSize-(prevSize%8)):0;
float tmp;
for( size_t j=begin;j<end;j++)
{
register size_t alignedPrev=prevSize>8?(prevSize-(prevSize%8)):0;
__m128 partialSolution = _mm_setzero_ps();
__m128 partialSolution2 = _mm_setzero_ps();
__m128 w;
__m128 sols;
__m128 w2;
__m128 sols2;
tmp=0;
for(register size_t k=alignedPrev;k<prevSize;k++)
{
tmp+=sol[k]*weights[layer][j][k];
}
partialSolution = _mm_setzero_ps();
partialSolution2 = _mm_set_ss(tmp);
for(register size_t k=0;k<alignedPrev;k+=8)
{
w = _mm_load_ps(this->weights[layer][j]+k);
@@ -87,17 +97,10 @@ void FeedForwardNetworkQuick::solvePart(float *newSolution, register size_t begi
partialSolution = _mm_hadd_ps(partialSolution, partialSolution);
partialSolution = _mm_hadd_ps(partialSolution, partialSolution);
_mm_store_ss(inputs[layer]+j,partialSolution);
for(register size_t k=alignedPrev;k<prevSize;k++)
{
inputs[layer][j]+=sol[k]*weights[layer][j][k];
}
partialSolution=_mm_load_ss(inputs[layer]+j);
__m128 temporaryConst = _mm_set1_ps(-lambda);
partialSolution=_mm_mul_ps(temporaryConst,partialSolution); //-lambda*sol[k]
partialSolution=_mm_mul_ps(temporaryConstLambda,partialSolution); //-lambda*sol[k]
partialSolution=exp_ps(partialSolution); //exp(sols)
temporaryConst = _mm_set1_ps(1.0);
partialSolution= _mm_add_ps(partialSolution,temporaryConst); //1+exp()
partialSolution= _mm_div_ps(temporaryConst,partialSolution);//1/....*/
partialSolution= _mm_add_ps(partialSolution,temporaryConst1); //1+exp()
partialSolution= _mm_div_ps(temporaryConst1,partialSolution);//1/....*/
_mm_store_ss(newSolution+j,partialSolution);
}
}else

View File

@@ -31,7 +31,7 @@ int main(int argc)
s.push_back(Shin::NeuronNetwork::Solution(std::vector<double>({0})));
p.push_back(X(std::vector<bool>({1})));
Shin::NeuronNetwork::FeedForwardNetworkQuick q({1,20000,20000,20000});
Shin::NeuronNetwork::FeedForwardNetworkQuick q({1,5000,5000,5000});
Shin::NeuronNetwork::Learning::BackPropagation b(q);
if(argc > 1)
{

View File

@@ -22,7 +22,7 @@ int main()
for (int test=0;test<2;test++)
{
Shin::NeuronNetwork::FeedForwardNetworkQuick q({2,4,1});
Shin::NeuronNetwork::FeedForwardNetworkQuick q({2,40,1});
Shin::NeuronNetwork::Learning::OpticalBackPropagation b(q);
srand(time(NULL));