Files
NeuralNetworkLib/include/NeuralNetwork/ConstructiveAlgorithms/CelularEncoding/CelularEncoding.h
2016-05-18 22:57:06 +02:00

182 lines
4.6 KiB
C++

#pragma once
#include "Exception.h"
#include "Cell.h"
#include <NeuralNetwork/Recurrent/Network.h>
#include <algorithm>
#include <limits>
#include <memory>
#include <vector>
namespace NeuralNetworks {
namespace ConstructiveAlgorithms {
namespace CelularEncoding {

/**
 * Grows a set of Cells by repeatedly calling step() and converts the
 * developed cells into a NeuralNetwork::Recurrent::Network (see run()).
 *
 * Instances are not copyable (copy constructor and copy assignment are deleted).
 */
class CelularEncoding {
	public:
		CelularEncoding(const CelularEncoding &) = delete;
		CelularEncoding &operator=(const CelularEncoding &) = delete;

		CelularEncoding() = default;

		/// Sets the activation function applied by run() to every non-input neuron.
		void setActivationFunction(const std::shared_ptr<NeuralNetwork::ActivationFunction::ActivationFunction> &fun) {
			_activationFunction = fun;
		}

		/// Sets how many times run() may call step() before giving up.
		void setMaxSteps(std::size_t steps) {
			_maxSteps = steps;
		}

		/// Sets the largest number of cells run() accepts after development (default 15).
		void setMaxCells(std::size_t maxCells) {
			_maxCells = maxCells;
		}

		/**
		 * Develops the cells and builds the corresponding recurrent network.
		 *
		 * step() is called until it returns 0 or until setMaxSteps() calls
		 * have been made; at least one step is always performed.
		 *
		 * Neuron numbering: index 0 is not assigned to any cell (each neuron's
		 * weight(0) receives its cell's bias). Input cells are numbered first,
		 * starting at 1, then cells that are output but not input, then all
		 * remaining (hidden) cells, each group in cell order.
		 *
		 * @return the network described by the developed cells.
		 * @throws Exception "Went over max steps" if step() still reports
		 *         developing cells when the step limit is reached.
		 * @throws Exception "Went over max cells" if more than the maximum
		 *         number of cells exist after development.
		 */
		NeuralNetwork::Recurrent::Network run() {
			std::size_t activeCells = 0;
			std::size_t steps = 0;
			do {
				activeCells = step();
				steps++;
			} while(activeCells > 0 && steps < _maxSteps);

			// Fail only if development is unfinished; completing exactly on the
			// last allowed step is not "going over" the limit.
			if(activeCells > 0) {
				throw Exception("Went over max steps");
			}
			if(cells.size() > _maxCells) {
				throw Exception("Went over max cells");
			}

			std::size_t inputs = 0;
			std::size_t outputs = 0;
			std::vector<std::size_t> cells2Neurons(cells.size());
			std::size_t nextNeuron = 1;

			for(std::size_t i = 0; i < cells.size(); i++) {
				if(cells[i]->isInput()) {
					cells2Neurons[i] = nextNeuron++;
					inputs++;
				}
			}
			for(std::size_t i = 0; i < cells.size(); i++) {
				if(cells[i]->isOutput()) {
					// A cell that is also an input keeps its input index but is
					// still counted as an output.
					if(!cells[i]->isInput()) {
						cells2Neurons[i] = nextNeuron++;
					}
					outputs++;
				}
			}
			for(std::size_t i = 0; i < cells.size(); i++) {
				if(!cells[i]->isOutput() && !cells[i]->isInput()) {
					cells2Neurons[i] = nextNeuron++;
				}
			}

			// Cells that are both input and output are counted twice above, so
			// clamp at zero instead of letting the unsigned subtraction wrap.
			const std::size_t ioNeurons = inputs + outputs;
			const std::size_t hiddenNeurons = cells.size() > ioNeurons ? cells.size() - ioNeurons : 0;

			NeuralNetwork::Recurrent::Network netw(inputs, outputs, hiddenNeurons);
			for(std::size_t i = 0; i < cells.size(); i++) {
				const auto &cell = cells[i];
				auto &neuron = netw[cells2Neurons[i]];

				// Only neurons numbered after the inputs get the configured activation function.
				if(cells2Neurons[i] > inputs) {
					neuron.setActivationFunction(*_activationFunction);
				}
				neuron.weight(0) = cell->getBias();
				for(const auto &link : cell->getLinks()) {
					// Disabled links still exist in the network, with a zero weight.
					if(link.status) {
						neuron.weight(cells2Neurons[link.neuron]) = link.value;
					} else {
						neuron.weight(cells2Neurons[link.neuron]) = 0.0;
					}
				}
			}
			return netw;
		}

		/**
		 * Appends a new Cell whose ID is its position in getCells().
		 * The returned reference refers to the shared_ptr-owned Cell, so it is
		 * not invalidated by later additions to the vector.
		 */
		Cell &addCell(const EvolutionaryAlgorithm::GeneticPrograming::CodeTree *c) {
			cells.push_back(std::make_shared<Cell>(cells.size(), c));
			return *cells.back();
		}

		/**
		 * Resets development to a single cell that is both input and output,
		 * runs the code set by setCode(), and has the life set by setInitialLife().
		 * The processing order is reset to contain only that cell.
		 */
		void setAcyclicTopology() {
			cells.clear();
			// IDs of the discarded cells must not survive the reset.
			_processingOrder.clear();

			Cell &cell = addCell(code);
			cell.setLife(_initialLife);
			_processingOrder.push_back(cell.getID());
			cell.setOutput();
			cell.setInput();
		}

		/// Same as setAcyclicTopology(), plus a Link(true, 1.0, id) from the single cell back to itself.
		void setCyclicTopology() {
			setAcyclicTopology();
			Link selfLoop(true, 1.0, cells.back()->getID());
			cells.back()->addLink(selfLoop);
		}

		/// Sets the code tree given to cells created by the topology setters (not owned).
		void setCode(const EvolutionaryAlgorithm::GeneticPrograming::CodeTree *code_) {
			code = code_;
		}

		/// Returns the code tree set by setCode() (nullptr if none was set).
		const EvolutionaryAlgorithm::GeneticPrograming::CodeTree *getCodeStart() const {
			return code;
		}

		/// Mutable access to the current cells, indexed by cell ID.
		std::vector<std::shared_ptr<Cell>> &getCells() {
			return cells;
		}

		/**
		 * Inserts @p id into the processing order immediately after the entry
		 * equal to the current cell ID, or appends it if that entry is absent.
		 */
		void addCellToProcessingOrder(std::size_t id) {
			const auto current = std::find(_processingOrder.begin(), _processingOrder.end(), currentID);
			if(current == _processingOrder.end()) {
				_processingOrder.push_back(id);
			} else {
				_processingOrder.insert(current + 1, id);
			}
		}

		/// Sets the life given to the initial cell by setAcyclicTopology()/setCyclicTopology().
		void setInitialLife(std::size_t life) {
			_initialLife = life;
		}

		/// Misspelled legacy name kept for backward compatibility; forwards to setInitialLife().
		void setInitiaLife(std::size_t life) {
			setInitialLife(life);
		}

	protected:
		/// Performs one development step; run() treats a return of 0 as "development finished".
		std::size_t step();

	private:
		std::size_t _maxCells = 15;
		std::size_t _maxSteps = std::numeric_limits<std::size_t>::max();
		std::size_t _initialLife = 2;
		std::shared_ptr<NeuralNetwork::ActivationFunction::ActivationFunction> _activationFunction = std::make_shared<NeuralNetwork::ActivationFunction::Sigmoid>();
		// Cell IDs; addCellToProcessingOrder() inserts new IDs after currentID.
		std::vector<std::size_t> _processingOrder = {};
		std::size_t currentID = 0;
		// Not owned; see setCode().
		const EvolutionaryAlgorithm::GeneticPrograming::CodeTree *code = nullptr;
		std::vector<std::shared_ptr<Cell>> cells = {};
};

}
}
}