#pragma once #include "BatchPropagation.h" namespace NeuralNetwork { namespace Learning { /** @class BackPropagation * @brief */ class BackPropagation : public BatchPropagation { public: BackPropagation(FeedForward::Network &feedForwardNetwork, std::shared_ptr correction = std::make_shared()): BatchPropagation(feedForwardNetwork,correction), learningCoefficient(0.4) { resize(); } BackPropagation(const BackPropagation&)=delete; BackPropagation& operator=(const NeuralNetwork::Learning::BackPropagation&) = delete; void setLearningCoefficient (const float& coefficient) { learningCoefficient=coefficient; } float getMomentumWeight() const { return momentumWeight; } void setMomentumWeight(const float& m) { momentumWeight=m; resize(); } float getWeightDecay() const { return weightDecay; } void setWeightDecay(const float& wd) { weightDecay=wd; } protected: virtual inline void resize() override { BatchPropagation::resize(); if(momentumWeight > 0.0) { _lastDeltas = _gradients; } } virtual void updateWeightsAndEndBatch() override; float learningCoefficient; float momentumWeight = 0.0; float weightDecay = 0.0; std::vector>> _lastDeltas = {}; }; } }