iRprop+ implementation
This commit is contained in:
64
include/NeuralNetwork/Learning/iRPropPlus.h
Normal file
64
include/NeuralNetwork/Learning/iRPropPlus.h
Normal file
@@ -0,0 +1,64 @@
|
||||
#pragma once
|
||||
|
||||
#include "BatchPropagation.h"
|
||||
|
||||
namespace NeuralNetwork {
|
||||
namespace Learning {
|
||||
|
||||
/** @class Resilient Propagation
|
||||
* @brief
|
||||
*/
|
||||
class iRPropPlus : public BatchPropagation {
|
||||
|
||||
public:
|
||||
iRPropPlus(FeedForward::Network &feedForwardNetwork, std::shared_ptr<CorrectionFunction::CorrectionFunction> correction = std::make_shared<CorrectionFunction::Linear>()):
|
||||
BatchPropagation(feedForwardNetwork, correction) {
|
||||
}
|
||||
|
||||
iRPropPlus(const iRPropPlus&)=delete;
|
||||
iRPropPlus& operator=(const NeuralNetwork::Learning::iRPropPlus&) = delete;
|
||||
|
||||
void setInitialWeightChange(float initVal) {
|
||||
initialWeightChange=initVal;
|
||||
}
|
||||
void setLearningCoefficient(float) {
|
||||
|
||||
}
|
||||
protected:
|
||||
|
||||
virtual inline void resize() override {
|
||||
BatchPropagation::resize();
|
||||
|
||||
_lastGradients =_gradients;
|
||||
|
||||
_changesOfWeightChanges = _lastGradients;
|
||||
for(std::size_t i = 1; i < _network.size(); i++) {
|
||||
for(std::size_t j = 0; j < _changesOfWeightChanges[i].size(); j++) {
|
||||
std::fill(_changesOfWeightChanges[i][j].begin(),_changesOfWeightChanges[i][j].end(),initialWeightChange);
|
||||
}
|
||||
}
|
||||
_lastWeightChanges = _lastGradients;
|
||||
for(std::size_t i = 1; i < _network.size(); i++) {
|
||||
for(std::size_t j = 0; j < _lastWeightChanges[i].size(); j++) {
|
||||
std::fill(_lastWeightChanges[i][j].begin(),_lastWeightChanges[i][j].end(),0.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void updateWeightsAndEndBatch() override;
|
||||
|
||||
std::vector<std::vector<std::vector<float>>> _lastGradients = {};
|
||||
std::vector<std::vector<std::vector<float>>> _lastWeightChanges = {};
|
||||
std::vector<std::vector<std::vector<float>>> _changesOfWeightChanges = {};
|
||||
|
||||
float _prevError=0;
|
||||
|
||||
float maxChangeOfWeights = 5;
|
||||
float minChangeOfWeights = 0.0001;
|
||||
|
||||
float initialWeightChange=0.02;
|
||||
float weightChangePlus=1.2;
|
||||
float weightChangeMinus=0.5;
|
||||
};
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user