serializatioin / deserialization and tests
This commit is contained in:
@@ -1,22 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include <NeuralNetwork/ActivationFunction/Sigmoid.h>
|
||||
#include <NeuralNetwork/BasisFunction/Linear.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <sstream>
|
||||
#include <limits>
|
||||
|
||||
#include <NeuralNetwork/ActivationFunction/Sigmoid.h>
|
||||
#include <NeuralNetwork/BasisFunction/Linear.h>
|
||||
|
||||
namespace NeuralNetwork
|
||||
{
|
||||
/**
|
||||
* @author Tomas Cernik (Tom.Cernik@gmail.com)
|
||||
* @brief Abstract class of neuron. All Neuron classes should derive from this on
|
||||
*/
|
||||
class NeuronInterface
|
||||
{
|
||||
class NeuronInterface : public SimpleJSON::SerializableObject {
|
||||
public:
|
||||
NeuronInterface(const unsigned long &_id=0): id(_id), weights(1),_output(1),
|
||||
_value(0) {
|
||||
@@ -33,12 +31,6 @@ namespace NeuralNetwork
|
||||
*/
|
||||
virtual ~NeuronInterface() {};
|
||||
|
||||
/**
|
||||
* @brief This is a virtual function for storing network
|
||||
* @returns json describing network and it's state
|
||||
*/
|
||||
virtual std::string stringify(const std::string &prefix="") const =0;
|
||||
|
||||
/**
|
||||
* @brief getter for neuron weight
|
||||
* @param &neuron is neuron it's weight is returned
|
||||
@@ -118,6 +110,8 @@ namespace NeuralNetwork
|
||||
* @brief id is identificator if neuron
|
||||
*/
|
||||
const unsigned long id;
|
||||
|
||||
typedef SimpleJSON::Factory<NeuronInterface> Factory;
|
||||
protected:
|
||||
std::vector<float> weights;
|
||||
float _output;
|
||||
@@ -128,8 +122,7 @@ namespace NeuralNetwork
|
||||
* @author Tomas Cernik (Tom.Cernik@gmail.com)
|
||||
* @brief Class of FeedForward neuron.
|
||||
*/
|
||||
class Neuron: public NeuronInterface
|
||||
{
|
||||
class Neuron: public NeuronInterface {
|
||||
public:
|
||||
Neuron(unsigned long _id=0, const ActivationFunction::ActivationFunction &activationFunction=ActivationFunction::Sigmoid(-4.9)):
|
||||
NeuronInterface(_id), basis(new BasisFunction::Linear),
|
||||
@@ -137,7 +130,7 @@ namespace NeuralNetwork
|
||||
_output=0.0;
|
||||
}
|
||||
|
||||
Neuron(const Neuron &r): NeuronInterface(r), basis(r.basis->clone()), activation(r.activation->clone()) {
|
||||
Neuron(const Neuron &r): NeuronInterface(r), basis(r.basis->clone().release()), activation(r.activation->clone()) {
|
||||
}
|
||||
|
||||
virtual ~Neuron() {
|
||||
@@ -147,8 +140,6 @@ namespace NeuralNetwork
|
||||
|
||||
Neuron operator=(const Neuron&) = delete;
|
||||
|
||||
virtual std::string stringify(const std::string &prefix="") const override;
|
||||
|
||||
float operator()(const std::vector<float>& inputs) {
|
||||
//compute value
|
||||
_value=basis->operator()(weights,inputs);
|
||||
@@ -174,7 +165,7 @@ namespace NeuralNetwork
|
||||
|
||||
virtual void setBasisFunction(const BasisFunction::BasisFunction& basisFunction) override {
|
||||
delete basis;
|
||||
basis=basisFunction.clone();
|
||||
basis=basisFunction.clone().release();
|
||||
|
||||
}
|
||||
|
||||
@@ -182,10 +173,16 @@ namespace NeuralNetwork
|
||||
delete activation;
|
||||
activation=activationFunction.clone();
|
||||
}
|
||||
|
||||
virtual SimpleJSON::Type::Object serialize() const override;
|
||||
|
||||
static std::unique_ptr<Neuron> deserialize(const SimpleJSON::Type::Object &obj);
|
||||
protected:
|
||||
|
||||
BasisFunction::BasisFunction *basis;
|
||||
ActivationFunction::ActivationFunction *activation;
|
||||
|
||||
SIMPLEJSON_REGISTER(NeuralNetwork::NeuronInterface::Factory, NeuralNetwork::Neuron,NeuralNetwork::Neuron::deserialize)
|
||||
};
|
||||
|
||||
class BiasNeuron: public NeuronInterface {
|
||||
@@ -203,8 +200,6 @@ namespace NeuralNetwork
|
||||
std::string text;
|
||||
};
|
||||
|
||||
virtual std::string stringify(const std::string& prefix = "") const override { return prefix+"{ \"class\" : \"NeuralNetwork::BiasNeuron\" }"; }
|
||||
|
||||
virtual float operator()(const std::vector< float >&) override { return 1.0; }
|
||||
|
||||
virtual BiasNeuron* clone() const { return new BiasNeuron(); }
|
||||
@@ -226,6 +221,15 @@ namespace NeuralNetwork
|
||||
throw usageException("activation function");
|
||||
}
|
||||
|
||||
virtual SimpleJSON::Type::Object serialize() const override {
|
||||
return {{"class", "NeuralNetwork::BiasNeuron"}};
|
||||
}
|
||||
|
||||
static std::unique_ptr<BiasNeuron> deserialize(const SimpleJSON::Type::Object &) {
|
||||
return std::unique_ptr<BiasNeuron>(new BiasNeuron());
|
||||
}
|
||||
|
||||
SIMPLEJSON_REGISTER(NeuralNetwork::NeuronInterface::Factory, NeuralNetwork::BiasNeuron,NeuralNetwork::BiasNeuron::deserialize)
|
||||
};
|
||||
|
||||
class InputNeuron: public NeuronInterface {
|
||||
@@ -247,8 +251,6 @@ namespace NeuralNetwork
|
||||
|
||||
}
|
||||
|
||||
virtual std::string stringify(const std::string& prefix = "") const override { return prefix+"{ \"class\" : \"NeuralNetwork::InputNeuron\", \"id\": "+std::to_string(id)+" }"; }
|
||||
|
||||
virtual float operator()(const std::vector< float >&) override { return 1.0; }
|
||||
|
||||
virtual InputNeuron* clone() const { return new InputNeuron(id); }
|
||||
@@ -269,5 +271,15 @@ namespace NeuralNetwork
|
||||
virtual void setActivationFunction(const ActivationFunction::ActivationFunction &) override {
|
||||
throw usageException("activation function");
|
||||
}
|
||||
|
||||
virtual SimpleJSON::Type::Object serialize() const override {
|
||||
return {{"class", "NeuralNetwork::InputNeuron"}, {"id", id}};
|
||||
}
|
||||
|
||||
static std::unique_ptr<NeuronInterface> deserialize(const SimpleJSON::Type::Object &obj) {
|
||||
return std::unique_ptr<NeuronInterface>(new InputNeuron(obj["id"].as<int>()));
|
||||
}
|
||||
|
||||
SIMPLEJSON_REGISTER(NeuralNetwork::NeuronInterface::Factory, NeuralNetwork::InputNeuron,NeuralNetwork::InputNeuron::deserialize)
|
||||
};
|
||||
}
|
||||
Reference in New Issue
Block a user