new function to support LSTM Unit

This commit is contained in:
2016-01-27 23:40:32 +01:00
parent d424d87535
commit 3c26c9641c
7 changed files with 97 additions and 14 deletions

View File

@@ -10,7 +10,7 @@ namespace BasisFunction {
class BasisFunction {
public:
virtual ~BasisFunction() {}
virtual float operator()(const std::vector<float>& weights, const std::vector<float>& input)=0;
virtual float operator()(const std::vector<float>& weights, const std::vector<float>& input) const =0;
/**
* @brief Function returns clone of object

View File

@@ -17,7 +17,7 @@ namespace BasisFunction {
public:
Linear() {}
inline virtual float computeStreaming(const std::vector<float>& weights, const std::vector<float>& input) override {
inline virtual float computeStreaming(const std::vector<float>& weights, const std::vector<float>& input) const override {
size_t inputSize=input.size();
size_t alignedPrev=inputSize-inputSize%4;
@@ -46,7 +46,7 @@ namespace BasisFunction {
return partialSolution.f[0];
}
inline virtual float compute(const std::vector<float>& weights, const std::vector<float>& input) override {
inline virtual float compute(const std::vector<float>& weights, const std::vector<float>& input) const override {
register float tmp = 0;
size_t inputSize=input.size();
for(size_t k=0;k<inputSize;k++) {

View File

@@ -0,0 +1,34 @@
#pragma once
#include "./BasisFunction.h"
namespace NeuralNetwork {
namespace BasisFunction {
class Product: public BasisFunction {
public:
Product() {}
/**
* @brief function computes product of inputs, where weight > 0.5
*/
inline virtual float operator()(const std::vector<float>& weights, const std::vector<float>& input) const override {
float product=1.0;
for(size_t i=0;i<weights.size();i++) {
if(weights[i] > 0.5)
product=product*input[i];
}
return product;
}
virtual Product* clone() const override {
return new Product();
}
virtual std::string stringify() const override {
return "{ \"class\": \"NeuralNetwork::BasisFunction::Product\" }";
}
};
}
}

View File

@@ -13,11 +13,13 @@ namespace BasisFunction {
float f[4];
};
virtual float operator()(const std::vector<float>& weights, const std::vector<float>& input) override {
virtual float operator()(const std::vector<float>& weights, const std::vector<float>& input) const override {
return computeStreaming(weights,input);
}
virtual float computeStreaming(const std::vector<float>& weights, const std::vector<float>& input) =0;
virtual float compute(const std::vector<float>& weights, const std::vector<float>& input) =0;
virtual float computeStreaming(const std::vector<float>& weights, const std::vector<float>& input) const =0;
virtual float compute(const std::vector<float>& weights, const std::vector<float>& input) const =0;
};
}
}