62 lines
1.4 KiB
C++
62 lines
1.4 KiB
C++
#pragma once
|
|
|
|
#include "BatchPropagation.h"
|
|
|
|
namespace NeuralNetwork {
|
|
namespace Learning {
|
|
|
|
/** @class BackPropagation
|
|
* @brief
|
|
*/
|
|
class BackPropagation : public BatchPropagation {
|
|
|
|
public:
|
|
BackPropagation(FeedForward::Network &feedForwardNetwork, std::shared_ptr<CorrectionFunction::CorrectionFunction> correction = std::make_shared<CorrectionFunction::Linear>()):
|
|
BatchPropagation(feedForwardNetwork,correction), learningCoefficient(0.4) {
|
|
resize();
|
|
}
|
|
|
|
BackPropagation(const BackPropagation&)=delete;
|
|
BackPropagation& operator=(const NeuralNetwork::Learning::BackPropagation&) = delete;
|
|
|
|
void setLearningCoefficient (const float& coefficient) {
|
|
learningCoefficient=coefficient;
|
|
}
|
|
|
|
float getMomentumWeight() const {
|
|
return momentumWeight;
|
|
}
|
|
|
|
void setMomentumWeight(const float& m) {
|
|
momentumWeight=m;
|
|
resize();
|
|
}
|
|
|
|
float getWeightDecay() const {
|
|
return weightDecay;
|
|
}
|
|
|
|
void setWeightDecay(const float& wd) {
|
|
weightDecay=wd;
|
|
}
|
|
|
|
protected:
|
|
|
|
virtual inline void resize() override {
|
|
BatchPropagation::resize();
|
|
if(momentumWeight > 0.0) {
|
|
_lastDeltas = _gradients;
|
|
}
|
|
}
|
|
|
|
virtual void updateWeightsAndEndBatch() override;
|
|
|
|
float learningCoefficient;
|
|
float momentumWeight = 0.0;
|
|
float weightDecay = 0.0;
|
|
|
|
std::vector<std::vector<std::vector<float>>> _lastDeltas = {};
|
|
|
|
};
|
|
}
|
|
} |