new function to support LSTM Unit
This commit is contained in:
34
include/NeuralNetwork/BasisFunction/Product.h
Normal file
34
include/NeuralNetwork/BasisFunction/Product.h
Normal file
@@ -0,0 +1,34 @@
|
||||
#pragma once
|
||||
|
||||
#include "./BasisFunction.h"
|
||||
|
||||
namespace NeuralNetwork {
|
||||
namespace BasisFunction {
|
||||
|
||||
class Product: public BasisFunction {
|
||||
public:
|
||||
Product() {}
|
||||
|
||||
/**
|
||||
* @brief function computes product of inputs, where weight > 0.5
|
||||
*/
|
||||
inline virtual float operator()(const std::vector<float>& weights, const std::vector<float>& input) const override {
|
||||
float product=1.0;
|
||||
for(size_t i=0;i<weights.size();i++) {
|
||||
if(weights[i] > 0.5)
|
||||
product=product*input[i];
|
||||
}
|
||||
return product;
|
||||
}
|
||||
|
||||
virtual Product* clone() const override {
|
||||
return new Product();
|
||||
}
|
||||
|
||||
virtual std::string stringify() const override {
|
||||
return "{ \"class\": \"NeuralNetwork::BasisFunction::Product\" }";
|
||||
}
|
||||
|
||||
};
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user