#pragma once #include #include #include #include "CorrectionFunction/Linear.h" namespace NeuralNetwork { namespace Learning { /** @class BackPropagation * @brief */ class BackPropagation { public: inline BackPropagation(FeedForward::Network &feedForwardNetwork, CorrectionFunction::CorrectionFunction *correction = new CorrectionFunction::Linear()): network(feedForwardNetwork), correctionFunction(correction),learningCoefficient(0.4), slopes() { resize(); } virtual ~BackPropagation() { delete correctionFunction; } BackPropagation(const BackPropagation&)=delete; BackPropagation& operator=(const NeuralNetwork::Learning::BackPropagation&) = delete; void teach(const std::vector &input, const std::vector &output); inline virtual void setLearningCoefficient (const float& coefficient) { learningCoefficient=coefficient; } float getMomentumWeight() const { return momentumWeight; } void setMomentumWeight(const float& m) { momentumWeight=m; resize(); } float getWeightDecay() const { return weightDecay; } void setWeightDecay(const float& wd) { weightDecay=wd; } std::size_t getBatchSize() const { return batchSize; } void setBatchSize(std::size_t size) { batchSize = size; } protected: virtual inline void resize() { if(slopes.size()!=network.size()) slopes.resize(network.size()); for(std::size_t i=0; i < network.size(); i++) { if(slopes[i].size()!=network[i].size()) slopes[i].resize(network[i].size()); } if(deltas.size() != network.size()) deltas.resize(network.size()); bool resized = false; for(std::size_t i = 0; i < network.size(); i++) { if(deltas[i].size() != network[i].size()) { deltas[i].resize(network[i].size()); resized = true; if(i > 0) { for(std::size_t j = 0; j < deltas[i].size(); j++) { deltas[i][j].resize(network[i - 1].size()); std::fill(deltas[i][j].begin(),deltas[i][j].end(),0.0); } } } } if(momentumWeight > 0.0 && (resized || lastDeltas.size() != deltas.size())) { lastDeltas = deltas; } } virtual void computeDeltas(const std::vector &input); void updateWeights(); virtual void computeSlopes(const std::vector &expectation); virtual void endBatch() { } FeedForward::Network &network; CorrectionFunction::CorrectionFunction *correctionFunction; float learningCoefficient; float momentumWeight = 0.0; float weightDecay = 0.0; std::size_t batchSize = 1; std::size_t currentBatchSize = 0; std::vector> slopes; std::vector>> deltas = {}; std::vector>> lastDeltas = {}; }; } }