324 lines
7.9 KiB
C++
324 lines
7.9 KiB
C++
#pragma once
|
|
|
|
#include <cstddef>
#include <exception>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <NeuralNetwork/ActivationFunction/Sigmoid.h>
#include <NeuralNetwork/BasisFunction/Linear.h>
|
|
|
|
namespace NeuralNetwork
|
|
{
|
|
/**
|
|
* @author Tomas Cernik (Tom.Cernik@gmail.com)
|
|
* @brief Abstract class of neuron. All Neuron classes should derive from this on
|
|
*/
|
|
class NeuronInterface
|
|
{
|
|
public:
|
|
|
|
/**
|
|
* @brief returns unique id for neuron
|
|
*/
|
|
virtual unsigned long id() const =0;
|
|
|
|
/**
|
|
* @brief virtual destructor for Neuron
|
|
*/
|
|
virtual ~NeuronInterface() {};
|
|
|
|
/**
|
|
* @brief This is a virtual function for storing network
|
|
* @returns json describing network and it's state
|
|
*/
|
|
virtual std::string stringify(const std::string &prefix="") const =0;
|
|
|
|
/**
|
|
* @brief Gets weight
|
|
* @param n is neuron
|
|
*/
|
|
virtual float getWeight(const NeuronInterface &n) const =0;
|
|
|
|
/**
|
|
* @brief Sets weight
|
|
* @param n is neuron
|
|
* @param w is new weight for input neuron n
|
|
*/
|
|
virtual void setWeight(const NeuronInterface& n ,const float &w) =0;
|
|
|
|
/**
|
|
* @brief Returns output of neuron
|
|
*/
|
|
virtual float output() const =0;
|
|
|
|
/**
|
|
* @brief Returns input of neuron
|
|
*/
|
|
virtual float value() const=0;
|
|
|
|
/**
|
|
* @brief Returns value for derivation of activation function
|
|
*/
|
|
// virtual float derivatedOutput() const=0;
|
|
|
|
/**
|
|
* @brief Function sets bias for neuron
|
|
* @param bias is new bias (initial value for neuron)
|
|
*/
|
|
virtual void setBias(const float &bias)=0;
|
|
|
|
/**
|
|
* @brief Function returns bias for neuron
|
|
*/
|
|
virtual float getBias() const=0;
|
|
|
|
virtual float operator()(const std::vector<float>& inputs) =0;
|
|
|
|
virtual void setInputSize(const std::size_t &size) = 0;
|
|
|
|
/**
|
|
* @brief Function returns clone of object
|
|
*/
|
|
virtual NeuronInterface* clone() const = 0;
|
|
|
|
/*
|
|
* @brief getter for basis function of neuron
|
|
*/
|
|
virtual BasisFunction::BasisFunction& getBasisFunction() =0;
|
|
|
|
/*
|
|
* @brief getter for activation function of neuron
|
|
*/
|
|
virtual ActivationFunction::ActivationFunction& getActivationFunction() =0;
|
|
};
|
|
|
|
/**
|
|
* @author Tomas Cernik (Tom.Cernik@gmail.com)
|
|
* @brief Class of FeedForward neuron.
|
|
*/
|
|
class Neuron: public NeuronInterface
|
|
{
|
|
public:
|
|
Neuron(unsigned long _id=0, const ActivationFunction::ActivationFunction &activationFunction=ActivationFunction::Sigmoid(-4.9)):
|
|
NeuronInterface(), basis(new BasisFunction::Linear),
|
|
activation(activationFunction.clone()),
|
|
id_(_id),weights(_id+1),_output(0),_value(0) {
|
|
}
|
|
|
|
Neuron(const Neuron &r): NeuronInterface(), basis(r.basis->clone()), activation(r.activation->clone()),id_(r.id_),
|
|
weights(r.weights), _output(r._output), _value(r._value) {
|
|
}
|
|
|
|
virtual ~Neuron() {
|
|
delete basis;
|
|
delete activation;
|
|
};
|
|
|
|
virtual std::string stringify(const std::string &prefix="") const override;
|
|
|
|
Neuron& operator=(const Neuron&r) {
|
|
id_=r.id_;
|
|
weights=r.weights;
|
|
basis=r.basis->clone();
|
|
activation=r.activation->clone();
|
|
return *this;
|
|
}
|
|
|
|
virtual long unsigned int id() const override {
|
|
return id_;
|
|
};
|
|
|
|
/**
|
|
* @brief Gets weight
|
|
* @param n is neuron
|
|
*/
|
|
virtual float getWeight(const NeuronInterface &n) const override {
|
|
return weights[n.id()];
|
|
}
|
|
|
|
/**
|
|
* @brief Sets weight
|
|
* @param n is neuron
|
|
* @param w is new weight for input neuron n
|
|
*/
|
|
virtual void setWeight(const NeuronInterface& n ,const float &w) override {
|
|
if(weights.size()<n.id()+1) {
|
|
weights.resize(n.id()+1);
|
|
}
|
|
weights[n.id()]=w;
|
|
}
|
|
|
|
virtual void setInputSize(const std::size_t &size) override {
|
|
if(weights.size()<size+1) {
|
|
weights.resize(size+1);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* @brief Returns output of neuron
|
|
*/
|
|
virtual float output() const override {
|
|
return _output;
|
|
}
|
|
|
|
/**
|
|
* @brief Returns input of neuron
|
|
*/
|
|
virtual float value() const override {
|
|
return _value;
|
|
}
|
|
|
|
/**
|
|
* @brief Function sets bias for neuron
|
|
* @param _bias is new bias (initial value for neuron)
|
|
*/
|
|
virtual void setBias(const float &_bias) override {
|
|
weights[0]=_bias;
|
|
}
|
|
|
|
/**
|
|
* @brief Function returns bias for neuron
|
|
*/
|
|
virtual float getBias() const override {
|
|
return weights[0];
|
|
}
|
|
|
|
float operator()(const std::vector<float>& inputs) {
|
|
//compute value
|
|
_value=basis->operator()(weights,inputs);
|
|
|
|
//compute output
|
|
_output=activation->operator()(_value);
|
|
|
|
return _output;
|
|
}
|
|
|
|
virtual Neuron* clone() const override {
|
|
Neuron *n = new Neuron;
|
|
*n=*this;
|
|
return n;
|
|
}
|
|
|
|
virtual BasisFunction::BasisFunction& getBasisFunction() override {
|
|
return *basis;
|
|
}
|
|
|
|
virtual ActivationFunction::ActivationFunction& getActivationFunction() override {
|
|
return *activation;
|
|
}
|
|
|
|
protected:
|
|
|
|
BasisFunction::BasisFunction *basis;
|
|
ActivationFunction::ActivationFunction *activation;
|
|
|
|
unsigned long id_;
|
|
std::vector<float> weights;
|
|
|
|
float _output;
|
|
float _value;
|
|
};
|
|
|
|
class BiasNeuron: public NeuronInterface {
|
|
public:
|
|
class usageException : public std::exception {
|
|
public:
|
|
usageException(const std::string &text_):text(text_) {
|
|
|
|
}
|
|
|
|
virtual const char* what() const noexcept override {
|
|
return text.c_str();
|
|
}
|
|
protected:
|
|
std::string text;
|
|
};
|
|
|
|
virtual float getBias() const override { return 0; };
|
|
|
|
virtual float getWeight(const NeuronInterface&) const override { return 0; }
|
|
|
|
virtual void setBias(const float&) override{ }
|
|
|
|
virtual float output() const override { return 1.0; };
|
|
|
|
virtual void setWeight(const NeuronInterface&, const float&) override { }
|
|
|
|
virtual std::string stringify(const std::string& prefix = "") const override { return prefix+"{ \"class\" : \"NeuralNetwork::BiasNeuron\" }"; }
|
|
|
|
virtual float value() const override { return 1.0; }
|
|
|
|
virtual long unsigned int id() const override { return 0; }
|
|
|
|
virtual float operator()(const std::vector< float >&) override { return 1.0; }
|
|
|
|
virtual void setInputSize(const std::size_t&) override {
|
|
}
|
|
|
|
virtual BiasNeuron* clone() const { return new BiasNeuron(); }
|
|
|
|
virtual BasisFunction::BasisFunction& getBasisFunction() override {
|
|
throw usageException("basis function");
|
|
}
|
|
|
|
virtual ActivationFunction::ActivationFunction& getActivationFunction() override {
|
|
throw usageException("activation function");
|
|
}
|
|
|
|
};
|
|
|
|
class InputNeuron: public NeuronInterface {
|
|
public:
|
|
class usageException : public std::exception {
|
|
public:
|
|
usageException(const std::string &text_):text(text_) {
|
|
|
|
}
|
|
|
|
virtual const char* what() const noexcept override {
|
|
return text.c_str();
|
|
}
|
|
protected:
|
|
std::string text;
|
|
};
|
|
|
|
InputNeuron(long unsigned int _id): id_(_id) {
|
|
|
|
}
|
|
|
|
virtual float getBias() const override { return 0; };
|
|
|
|
virtual float getWeight(const NeuronInterface&) const override { return 0; }
|
|
|
|
virtual void setBias(const float&) override{ }
|
|
|
|
virtual float output() const override { return 1.0; };
|
|
|
|
virtual void setWeight(const NeuronInterface&, const float&) override { }
|
|
|
|
virtual std::string stringify(const std::string& prefix = "") const override { return prefix+"{ \"class\" : \"NeuralNetwork::InputNeuron\", \"id\": "+std::to_string(id_)+" }"; }
|
|
|
|
virtual float value() const override { return 1.0; }
|
|
|
|
virtual long unsigned int id() const override { return id_; }
|
|
|
|
virtual float operator()(const std::vector< float >&) override { return 1.0; }
|
|
|
|
virtual void setInputSize(const std::size_t&) override {
|
|
}
|
|
|
|
virtual InputNeuron* clone() const { return new InputNeuron(id_); }
|
|
|
|
virtual BasisFunction::BasisFunction& getBasisFunction() override {
|
|
throw usageException("basis function");
|
|
}
|
|
|
|
virtual ActivationFunction::ActivationFunction& getActivationFunction() override {
|
|
throw usageException("activation function");
|
|
}
|
|
protected:
|
|
long unsigned int id_;
|
|
};
|
|
} |