epochs
This commit is contained in:
@@ -31,6 +31,7 @@ namespace NeuralNetwork {
|
|||||||
network.randomizeWeights();
|
network.randomizeWeights();
|
||||||
|
|
||||||
_epoch = 0;
|
_epoch = 0;
|
||||||
|
_neurons = 0;
|
||||||
float error;
|
float error;
|
||||||
float lastError;
|
float lastError;
|
||||||
if(_maxRandomOutputWeights) {
|
if(_maxRandomOutputWeights) {
|
||||||
@@ -38,7 +39,7 @@ namespace NeuralNetwork {
|
|||||||
} else {
|
} else {
|
||||||
error = trainOutputs(network, patterns);
|
error = trainOutputs(network, patterns);
|
||||||
}
|
}
|
||||||
while(_epoch++ < _maxHiddenUnits && error > _errorTreshold) {
|
while(_epoch++ < _maxHiddenUnits && _neurons++ < _maxEpochs && error > _errorTreshold) {
|
||||||
std::vector<std::shared_ptr<Neuron>> candidates = createCandidates(network.getNeuronSize() - outputs);
|
std::vector<std::shared_ptr<Neuron>> candidates = createCandidates(network.getNeuronSize() - outputs);
|
||||||
|
|
||||||
std::pair<std::shared_ptr<Neuron>, std::vector<float>> candidate = trainCandidates(network, candidates, patterns);
|
std::pair<std::shared_ptr<Neuron>, std::vector<float>> candidate = trainCandidates(network, candidates, patterns);
|
||||||
@@ -58,6 +59,8 @@ namespace NeuralNetwork {
|
|||||||
network.removeLastHiddenNeuron();
|
network.removeLastHiddenNeuron();
|
||||||
error=lastError;
|
error=lastError;
|
||||||
std::cout << "PRUNED\n";
|
std::cout << "PRUNED\n";
|
||||||
|
} else {
|
||||||
|
_neurons++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,6 +80,10 @@ namespace NeuralNetwork {
|
|||||||
_maxHiddenUnits = neurons;
|
_maxHiddenUnits = neurons;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void setMaximumEpochs(std::size_t epochs) {
|
||||||
|
_maxEpochs = epochs;
|
||||||
|
}
|
||||||
|
|
||||||
void setActivationFunction(const ActivationFunction::ActivationFunction &function) {
|
void setActivationFunction(const ActivationFunction::ActivationFunction &function) {
|
||||||
_activFunction = std::shared_ptr<ActivationFunction::ActivationFunction>(function.clone());
|
_activFunction = std::shared_ptr<ActivationFunction::ActivationFunction>(function.clone());
|
||||||
}
|
}
|
||||||
@@ -160,6 +167,8 @@ namespace NeuralNetwork {
|
|||||||
float _pruningLimit=0.0;
|
float _pruningLimit=0.0;
|
||||||
|
|
||||||
std::size_t _epoch = 0;
|
std::size_t _epoch = 0;
|
||||||
|
std::size_t _neurons = 0;
|
||||||
|
std::size_t _maxEpochs = 20;
|
||||||
std::size_t _maxHiddenUnits = 20;
|
std::size_t _maxHiddenUnits = 20;
|
||||||
std::size_t _maxRandomOutputWeights = 0;
|
std::size_t _maxRandomOutputWeights = 0;
|
||||||
std::size_t _numberOfCandidates;
|
std::size_t _numberOfCandidates;
|
||||||
|
|||||||
Reference in New Issue
Block a user