Files
NeuralNetworkLib/include/NeuralNetwork/BasisFunction/Product.h

39 lines
1.0 KiB
C++

#pragma once
#include <algorithm>
#include <cstddef>
#include <memory>
#include <vector>

#include "./BasisFunction.h"
namespace NeuralNetwork {
namespace BasisFunction {
/**
 * @brief Basis function that multiplies together the inputs selected by the weights.
 *
 * Each weight acts as a binary mask: input[i] takes part in the product only
 * when weights[i] > 0.5. The class has no state, so clone/deserialize simply
 * create a fresh instance.
 */
class Product: public BasisFunction {
	public:
		Product() = default;

		/**
		 * @brief Computes the product of the inputs whose weight is greater than 0.5.
		 *
		 * @param weights selection weights; weights[i] > 0.5 selects input[i]
		 * @param input   input values
		 * @return product of the selected inputs, or 1.0f (the empty product)
		 *         if no input is selected
		 *
		 * Only the first min(weights.size(), input.size()) pairs are considered.
		 * The original loop bounded only by weights.size() and could read past
		 * the end of `input` when the vectors had different lengths.
		 */
		virtual float operator()(const std::vector<float>& weights, const std::vector<float>& input) const override {
			const std::size_t count = std::min(weights.size(), input.size());
			float product = 1.0f;
			for(std::size_t i = 0; i < count; ++i) {
				if(weights[i] > 0.5f) {
					product *= input[i];
				}
			}
			return product;
		}

		/// @brief Returns a new, independent Product instance (the class carries no state).
		virtual std::unique_ptr<BasisFunction> clone() const override {
			return std::unique_ptr<BasisFunction>(new Product());
		}

		/// @brief Serializes this basis function as an object holding only its class name.
		virtual SimpleJSON::Type::Object serialize() const override {
			return {{"class", "NeuralNetwork::BasisFunction::Product"}};
		}

		/// @brief Creates a Product from serialized data; the object's contents are ignored
		///        because Product has no parameters.
		static std::unique_ptr<Product> deserialize(const SimpleJSON::Type::Object &) {
			return std::unique_ptr<Product>(new Product());
		}

		NEURAL_NETWORK_REGISTER_BASIS_FUNCTION(NeuralNetwork::BasisFunction::Product, deserialize)
};
}
}