Files
NeuralNetworkLib/include/NeuralNetwork/Neuron.h
2016-02-18 18:53:59 +01:00

240 lines
6.0 KiB
C++

#pragma once
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include <NeuralNetwork/ActivationFunction/Sigmoid.h>
#include <NeuralNetwork/BasisFunction/Linear.h>
namespace NeuralNetwork
{
/**
 * @author Tomas Cernik (Tom.Cernik@gmail.com)
 * @brief Abstract base class of a neuron. All neuron classes should derive from this one.
 *
 * Holds the neuron's id, a vector of weights indexed by the id of the neuron a
 * weight belongs to, and the last computed value and output.
 */
class NeuronInterface
{
	public:
		/**
		 * @brief Creates a neuron with the given id.
		 * @param _id identifier of the neuron
		 *
		 * The weight vector starts with a single weight of 0, _output starts at 1
		 * and _value at 0.
		 */
		NeuronInterface(const unsigned long &_id=0): id(_id), weights(1), _output(1),
			_value(0) {
		}

		/**
		 * @brief Copy constructor; copies id, weights, output and value of @p r.
		 */
		NeuronInterface(const NeuronInterface &r): id(r.id), weights(r.weights), _output(r._output),
			_value(r._value) {
		}

		/**
		 * @brief virtual destructor for Neuron
		 */
		virtual ~NeuronInterface() {}

		/**
		 * @brief This is a virtual function for storing network
		 * @param prefix text prepended to the produced string
		 * @returns json describing network and its state
		 */
		virtual std::string stringify(const std::string &prefix="") const =0;

		/**
		 * @brief getter for neuron weight
		 * @param neuron neuron whose weight is returned; its id indexes the weight vector
		 * @note No bounds check is done: neuron.id must be smaller than the
		 *       current weight count (see setInputSize).
		 */
		inline virtual float weight(const NeuronInterface &neuron) const final {
			return weights[neuron.id];
		}

		/**
		 * @brief getter for neuron weight
		 * @param neuronID id of the neuron whose weight is returned
		 * @note No bounds check is done: neuronID must be smaller than the
		 *       current weight count (see setInputSize).
		 */
		inline virtual float weight(const std::size_t &neuronID) const final {
			return weights[neuronID];
		}

		/**
		 * @brief mutable access to neuron weight
		 * @param neuron neuron whose weight is returned; its id indexes the weight vector
		 * @returns reference to the stored weight, usable for assignment
		 * @note No bounds check is done: neuron.id must be smaller than the
		 *       current weight count (see setInputSize).
		 */
		inline virtual float& weight(const NeuronInterface &neuron) final {
			return weights[neuron.id];
		}

		/**
		 * @brief mutable access to neuron weight
		 * @param neuronID id of the neuron whose weight is returned
		 * @returns reference to the stored weight, usable for assignment
		 * @note No bounds check is done: neuronID must be smaller than the
		 *       current weight count (see setInputSize).
		 */
		inline virtual float& weight(const std::size_t &neuronID) final {
			return weights[neuronID];
		}

		/**
		 * @brief Returns output of neuron
		 */
		inline virtual float output() const final {
			return _output;
		}

		/**
		 * @brief Returns value of neuron (the result stored before activation)
		 */
		inline virtual float value() const final {
			return _value;
		}

		/**
		 * @brief Evaluates the neuron for the given inputs.
		 * @param inputs input values; how they are used is up to the implementation
		 * @returns output of the neuron
		 */
		virtual float operator()(const std::vector<float>& inputs) =0;

		/**
		 * @brief Grows the weight vector to at least @p size entries.
		 *
		 * New entries are value-initialized (0). The vector is never shrunk, so a
		 * smaller @p size leaves the weights unchanged.
		 */
		inline virtual void setInputSize(const std::size_t &size) final {
			if(weights.size()<size) {
				weights.resize(size);
			}
		}

		/**
		 * @brief Function returns clone of object
		 * @returns newly heap-allocated object; the caller is responsible for deleting it
		 */
		virtual NeuronInterface* clone() const = 0;

		/**
		 * @brief getter for basis function of neuron
		 * @note Neuron types without a basis function may throw instead.
		 */
		virtual BasisFunction::BasisFunction& getBasisFunction() =0;

		/**
		 * @brief getter for activation function of neuron
		 * @note Neuron types without an activation function may throw instead.
		 */
		virtual ActivationFunction::ActivationFunction& getActivationFunction() =0;

		/**
		 * @brief id is identifier of neuron
		 */
		const unsigned long id;

	protected:
		/// weights indexed by the id of the neuron each weight belongs to
		std::vector<float> weights;
		/// last computed output
		float _output;
		/// last computed value (before activation)
		float _value;
};
/**
* @author Tomas Cernik (Tom.Cernik@gmail.com)
* @brief Class of FeedForward neuron.
*/
class Neuron: public NeuronInterface
{
public:
Neuron(unsigned long _id=0, const ActivationFunction::ActivationFunction &activationFunction=ActivationFunction::Sigmoid(-4.9)):
NeuronInterface(_id), basis(new BasisFunction::Linear),
activation(activationFunction.clone()) {
_output=0.0;
}
Neuron(const Neuron &r): NeuronInterface(r), basis(r.basis->clone()), activation(r.activation->clone()) {
}
virtual ~Neuron() {
delete basis;
delete activation;
};
Neuron operator=(const Neuron&) = delete;
virtual std::string stringify(const std::string &prefix="") const override;
float operator()(const std::vector<float>& inputs) {
//compute value
_value=basis->operator()(weights,inputs);
//compute output
_output=activation->operator()(_value);
return _output;
}
virtual Neuron* clone() const override {
Neuron *n = new Neuron(*this);
return n;
}
virtual BasisFunction::BasisFunction& getBasisFunction() override {
return *basis;
}
virtual ActivationFunction::ActivationFunction& getActivationFunction() override {
return *activation;
}
protected:
BasisFunction::BasisFunction *basis;
ActivationFunction::ActivationFunction *activation;
};
class BiasNeuron: public NeuronInterface {
public:
class usageException : public std::exception {
public:
usageException(const std::string &text_):text(text_) {
}
virtual const char* what() const noexcept override {
return text.c_str();
}
protected:
std::string text;
};
virtual std::string stringify(const std::string& prefix = "") const override { return prefix+"{ \"class\" : \"NeuralNetwork::BiasNeuron\" }"; }
virtual float operator()(const std::vector< float >&) override { return 1.0; }
virtual BiasNeuron* clone() const { return new BiasNeuron(); }
virtual BasisFunction::BasisFunction& getBasisFunction() override {
throw usageException("basis function");
}
virtual ActivationFunction::ActivationFunction& getActivationFunction() override {
throw usageException("activation function");
}
};
class InputNeuron: public NeuronInterface {
public:
class usageException : public std::exception {
public:
usageException(const std::string &text_):text(text_) {
}
virtual const char* what() const noexcept override {
return text.c_str();
}
protected:
std::string text;
};
InputNeuron(long unsigned int _id): NeuronInterface(_id) {
}
virtual std::string stringify(const std::string& prefix = "") const override { return prefix+"{ \"class\" : \"NeuralNetwork::InputNeuron\", \"id\": "+std::to_string(id)+" }"; }
virtual float operator()(const std::vector< float >&) override { return 1.0; }
virtual InputNeuron* clone() const { return new InputNeuron(id); }
virtual BasisFunction::BasisFunction& getBasisFunction() override {
throw usageException("basis function");
}
virtual ActivationFunction::ActivationFunction& getActivationFunction() override {
throw usageException("activation function");
}
};
}