#pragma once #include #include #include #include #include #include namespace NeuralNetwork { /** * @author Tomas Cernik (Tom.Cernik@gmail.com) * @brief Abstract class of neuron. All Neuron classes should derive from this on */ class NeuronInterface { public: /** * @brief returns unique id for neuron */ virtual unsigned long id() const =0; /** * @brief virtual destructor for Neuron */ virtual ~NeuronInterface() {}; /** * @brief This is a virtual function for storing network * @returns json describing network and it's state */ virtual std::string stringify(const std::string &prefix="") const =0; /** * @brief Gets weight * @param n is neuron */ virtual float getWeight(const NeuronInterface &n) const =0; /** * @brief Sets weight * @param n is neuron * @param w is new weight for input neuron n */ virtual void setWeight(const NeuronInterface& n ,const float &w) =0; /** * @brief Returns output of neuron */ virtual float output() const =0; /** * @brief Returns input of neuron */ virtual float value() const=0; /** * @brief Returns value for derivation of activation function */ // virtual float derivatedOutput() const=0; /** * @brief Function sets bias for neuron * @param bias is new bias (initial value for neuron) */ virtual void setBias(const float &bias)=0; /** * @brief Function returns bias for neuron */ virtual float getBias() const=0; virtual float operator()(const std::vector& inputs) =0; virtual void setInputSize(const std::size_t &size) = 0; /** * @brief Function returns clone of object */ virtual NeuronInterface* clone() const = 0; /* * @brief getter for basis function of neuron */ virtual BasisFunction::BasisFunction& getBasisFunction() =0; /* * @brief getter for activation function of neuron */ virtual ActivationFunction::ActivationFunction& getActivationFunction() =0; }; /** * @author Tomas Cernik (Tom.Cernik@gmail.com) * @brief Class of FeedForward neuron. */ class Neuron: public NeuronInterface { public: Neuron(unsigned long _id=0, const ActivationFunction::ActivationFunction &activationFunction=ActivationFunction::Sigmoid(-4.9)): NeuronInterface(), basis(new BasisFunction::Linear), activation(activationFunction.clone()), id_(_id),weights(_id+1),_output(0),_value(0) { } Neuron(const Neuron &r): NeuronInterface(), basis(r.basis->clone()), activation(r.activation->clone()),id_(r.id_), weights(r.weights), _output(r._output), _value(r._value) { } virtual ~Neuron() { delete basis; delete activation; }; virtual std::string stringify(const std::string &prefix="") const override; Neuron& operator=(const Neuron&r) { id_=r.id_; weights=r.weights; basis=r.basis->clone(); activation=r.activation->clone(); return *this; } virtual long unsigned int id() const override { return id_; }; /** * @brief Gets weight * @param n is neuron */ virtual float getWeight(const NeuronInterface &n) const override { return weights[n.id()]; } /** * @brief Sets weight * @param n is neuron * @param w is new weight for input neuron n */ virtual void setWeight(const NeuronInterface& n ,const float &w) override { if(weights.size()& inputs) { //compute value _value=basis->operator()(weights,inputs); //compute output _output=activation->operator()(_value); return _output; } virtual Neuron* clone() const override { Neuron *n = new Neuron; *n=*this; return n; } virtual BasisFunction::BasisFunction& getBasisFunction() override { return *basis; } virtual ActivationFunction::ActivationFunction& getActivationFunction() override { return *activation; } protected: BasisFunction::BasisFunction *basis; ActivationFunction::ActivationFunction *activation; unsigned long id_; std::vector weights; float _output; float _value; }; class BiasNeuron: public NeuronInterface { public: class usageException : public std::exception { public: usageException(const std::string &text_):text(text_) { } virtual const char* what() const noexcept override { return text.c_str(); } protected: std::string text; }; virtual float getBias() const override { return 0; }; virtual float getWeight(const NeuronInterface&) const override { return 0; } virtual void setBias(const float&) override{ } virtual float output() const override { return 1.0; }; virtual void setWeight(const NeuronInterface&, const float&) override { } virtual std::string stringify(const std::string& prefix = "") const override { return prefix+"{ \"class\" : \"NeuralNetwork::BiasNeuron\" }"; } virtual float value() const override { return 1.0; } virtual long unsigned int id() const override { return 0; } virtual float operator()(const std::vector< float >&) override { return 1.0; } virtual void setInputSize(const std::size_t&) override { } virtual BiasNeuron* clone() const { return new BiasNeuron(); } virtual BasisFunction::BasisFunction& getBasisFunction() override { throw usageException("basis function"); } virtual ActivationFunction::ActivationFunction& getActivationFunction() override { throw usageException("activation function"); } }; class InputNeuron: public NeuronInterface { public: class usageException : public std::exception { public: usageException(const std::string &text_):text(text_) { } virtual const char* what() const noexcept override { return text.c_str(); } protected: std::string text; }; InputNeuron(long unsigned int _id): id_(_id) { } virtual float getBias() const override { return 0; }; virtual float getWeight(const NeuronInterface&) const override { return 0; } virtual void setBias(const float&) override{ } virtual float output() const override { return 1.0; }; virtual void setWeight(const NeuronInterface&, const float&) override { } virtual std::string stringify(const std::string& prefix = "") const override { return prefix+"{ \"class\" : \"NeuralNetwork::InputNeuron\", \"id\": "+std::to_string(id_)+" }"; } virtual float value() const override { return 1.0; } virtual long unsigned int id() const override { return id_; } virtual float operator()(const std::vector< float >&) override { return 1.0; } virtual void setInputSize(const std::size_t&) override { } virtual InputNeuron* clone() const { return new InputNeuron(id_); } virtual BasisFunction::BasisFunction& getBasisFunction() override { throw usageException("basis function"); } virtual ActivationFunction::ActivationFunction& getActivationFunction() override { throw usageException("activation function"); } protected: long unsigned int id_; }; }