cascade network implementation

This commit is contained in:
2016-05-02 16:58:28 +02:00
parent 2cf680a942
commit 6a17694a6b

View File

@@ -0,0 +1,105 @@
#pragma once
#include "../Network.h"
namespace NeuralNetwork {
namespace Cascade {
class Network : public NeuralNetwork::Network {
public:
/**
* @brief Constructor for Network
* @param _inputSize is number of inputs to network
*/
Network(std::size_t inputSize, std::size_t outputSize) : NeuralNetwork::Network(), _inputSize(inputSize), _outputSize(outputSize) {
_neurons.push_back(std::make_shared<BiasNeuron>());
for(std::size_t i = 0; i < inputSize; i++) {
_neurons.push_back(std::make_shared<InputNeuron>(_neurons.size()));
}
for(std::size_t i = 0; i < outputSize; i++) {
_neurons.push_back(std::make_shared<Neuron>(_neurons.size()));
_neurons.back()->setInputSize(inputSize + 1); // +1 is bias
}
}
virtual std::vector<float> computeOutput(const std::vector<float> &input) override {
std::vector<float> compute;
compute.resize(_neurons.size());
compute[0] = 1.0;
for(std::size_t i = 1; i <= _inputSize; i++) {
compute[i] = input[i - 1];
}
// 0 is bias, 1-_inputSize is input
for(std::size_t i = _inputSize + 1; i < _neurons.size(); i++) {
compute[i] = (*_neurons[i].get())(compute);
}
return std::vector<float>(compute.end() - _outputSize, compute.end());
}
std::size_t getNeuronSize() const {
return _neurons.size();
}
std::shared_ptr<NeuronInterface> getNeuron(std::size_t id) {
return _neurons[id];
}
std::shared_ptr<NeuronInterface> addNeuron() {
_neurons.push_back(std::make_shared<Neuron>());
auto neuron = _neurons.back();
neuron->setInputSize(_neurons.size() - _outputSize);
// 0 is bias, 1-_inputSize is input
std::size_t maxIndexOfNeuron = _neurons.size() - 1;
// move output to right position
for(std::size_t i = 0; i < _outputSize; i++) {
std::swap(_neurons[maxIndexOfNeuron - i], _neurons[maxIndexOfNeuron - i - 1]);
}
for(std::size_t i = 0; i < _outputSize; i++) {
_neurons[maxIndexOfNeuron - i]->setInputSize(_neurons.size() - _outputSize);
}
return neuron;
}
virtual SimpleJSON::Type::Object serialize() const override {
std::vector<SimpleJSON::Value> neuronsSerialized;
for(auto &neuron: _neurons) {
neuronsSerialized.push_back(neuron->serialize());
}
return {
{"class", "NeuralNetwork::Recurrent::Network"},
{"inputSize", _inputSize},
{"outputSize", _outputSize},
{"neurons", neuronsSerialized}
};
}
static std::unique_ptr<Network> deserialize(const SimpleJSON::Type::Object &obj) {
const int inputSize = obj["inputSize"].as<int>();
const int outputSize = obj["outputSize"].as<int>();
Network *net = new Network(inputSize, outputSize);
net->_neurons.clear();
for(const auto& neuronObj: obj["neurons"].as<SimpleJSON::Type::Array>()) {
net->_neurons.push_back(Neuron::Factory::deserialize(neuronObj.as<SimpleJSON::Type::Object>()));
}
return std::unique_ptr<Network>(net);
}
protected:
std::size_t _inputSize;
std::size_t _outputSize;
std::vector<std::shared_ptr<NeuronInterface>> _neurons = {};
SIMPLEJSON_REGISTER(NeuralNetwork::Cascade::Network::Factory, NeuralNetwork::Cascade::Network, deserialize)
};
}
}