// NeuralNetworkLib/include/NeuralNetwork/BasisFunction/Radial.h

#pragma once
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <memory>
#include <vector>

#include "./BasisFunction.h"
namespace NeuralNetwork
{
namespace BasisFunction
{
/// Radial basis function: the Euclidean (L2) distance between a neuron's
/// weight vector and its input vector,
///   sqrt( sum_i (input[i] - weights[i])^2 ).
class Radial: public BasisFunction
{
public:
	Radial() = default;

	/// @brief Computes the Euclidean distance between @p weights and @p input.
	/// @param weights Weight vector the input is compared against.
	/// @param input   Input vector, compared element-wise with @p weights.
	/// @return Square root of the sum of squared element-wise differences.
	///
	/// Only the first min(weights.size(), input.size()) elements are compared.
	/// Bounding the loop this way guarantees neither vector is read past its
	/// end when the two sizes disagree.
	virtual float operator()(const std::vector<float>& weights, const std::vector<float>& input) const override {
		const std::size_t count = std::min(weights.size(), input.size());
		float sum = 0.0f;
		for(std::size_t i = 0; i < count; ++i) {
			// Square by multiplication: cheaper than pow(d, 2) and stays in float
			// instead of round-tripping through double.
			const float diff = input[i] - weights[i];
			sum += diff * diff;
		}
		// std::sqrt has a float overload, so no double promotion here either.
		return std::sqrt(sum);
	}

	/// @brief Returns a new, independent Radial instance (the class has no state).
	virtual std::unique_ptr<BasisFunction> clone() const override {
		return std::unique_ptr<BasisFunction>(new Radial());
	}

	/// @brief Serializes this basis function as an object holding only its class name.
	virtual SimpleJSON::Type::Object serialize() const override {
		return {{"class", "NeuralNetwork::BasisFunction::Radial"}};
	}

	/// @brief Reconstructs a Radial from its serialized form.
	/// The object's contents are ignored because the class carries no state.
	static std::unique_ptr<Radial> deserialize(const SimpleJSON::Type::Object &) {
		return std::unique_ptr<Radial>(new Radial());
	}

	NEURAL_NETWORK_REGISTER_BASIS_FUNCTION(NeuralNetwork::BasisFunction::Radial, deserialize)
};
}
}