#pragma once #include "../Network.h" #include "Layer.h" #include #include namespace NeuralNetwork { namespace FeedForward { /** * @author Tomas Cernik (Tom.Cernik@gmail.com) * @brief FeedForward model of Artifical neural network */ class Network: public NeuralNetwork::Network { public: /** * @brief Constructor for Network * @param _inputSize is number of inputs to network */ inline Network(size_t _inputSize):NeuralNetwork::Network(_inputSize,_inputSize),layers(),_partialInput(_inputSize),_partialOutput(_inputSize) { appendLayer(_inputSize); }; /** * @brief Virtual destructor for Network */ virtual ~Network() { for(auto &layer:layers) { delete layer; } } Layer& appendLayer(std::size_t size=1, const ActivationFunction::ActivationFunction &activationFunction=ActivationFunction::Sigmoid(-4.9)) { layers.push_back(new Layer(size,activationFunction)); if(layers.size() > 1) { layers.back()->setInputSize(layers[layers.size() - 2]->size()); } else { _inputs=size; } if(_partialInput.size() < size+1) { _partialInput.resize(size+1); } if(_partialOutput.size() < size+1) { _partialOutput.resize(size+1); } _outputs=size; return *layers.back(); } Layer& operator[](const std::size_t &id) { return *layers[id]; } void randomizeWeights(); std::size_t size() const { return layers.size(); }; /** * @brief This is a function to compute one iterations of network * @param input is input of network * @returns output of network */ virtual std::vector computeOutput(const std::vector& input) override; using NeuralNetwork::Network::stringify; virtual SimpleJSON::Type::Object serialize() const override { std::vector layersSerialized; for(std::size_t i=0;iserialize()); } return { {"class", "NeuralNetwork::FeedForward::Network"}, {"layers", layersSerialized }, }; } static std::unique_ptr deserialize(const SimpleJSON::Type::Object&); typedef SimpleJSON::Factory Factory; protected: std::vector layers; std::vector _partialInput = {}; std::vector _partialOutput = {}; private: inline Network():NeuralNetwork::Network(0,0),layers() { }; SIMPLEJSON_REGISTER(NeuralNetwork::FeedForward::Network::Factory, NeuralNetwork::FeedForward::Network,NeuralNetwork::FeedForward::Network::deserialize) }; } }