This commit is contained in:
2016-05-21 00:18:44 +02:00
parent 4af18f014c
commit 368a73ccd6

View File

@@ -31,6 +31,7 @@ namespace NeuralNetwork {
network.randomizeWeights(); network.randomizeWeights();
_epoch = 0; _epoch = 0;
_neurons = 0;
float error; float error;
float lastError; float lastError;
if(_maxRandomOutputWeights) { if(_maxRandomOutputWeights) {
@@ -38,7 +39,7 @@ namespace NeuralNetwork {
} else { } else {
error = trainOutputs(network, patterns); error = trainOutputs(network, patterns);
} }
while(_epoch++ < _maxHiddenUnits && error > _errorTreshold) { while(_epoch++ < _maxHiddenUnits && _neurons++ < _maxEpochs && error > _errorTreshold) {
std::vector<std::shared_ptr<Neuron>> candidates = createCandidates(network.getNeuronSize() - outputs); std::vector<std::shared_ptr<Neuron>> candidates = createCandidates(network.getNeuronSize() - outputs);
std::pair<std::shared_ptr<Neuron>, std::vector<float>> candidate = trainCandidates(network, candidates, patterns); std::pair<std::shared_ptr<Neuron>, std::vector<float>> candidate = trainCandidates(network, candidates, patterns);
@@ -58,6 +59,8 @@ namespace NeuralNetwork {
network.removeLastHiddenNeuron(); network.removeLastHiddenNeuron();
error=lastError; error=lastError;
std::cout << "PRUNED\n"; std::cout << "PRUNED\n";
} else {
_neurons++;
} }
} }
@@ -77,6 +80,10 @@ namespace NeuralNetwork {
_maxHiddenUnits = neurons; _maxHiddenUnits = neurons;
} }
void setMaximumEpochs(std::size_t epochs) {
_maxEpochs = epochs;
}
void setActivationFunction(const ActivationFunction::ActivationFunction &function) { void setActivationFunction(const ActivationFunction::ActivationFunction &function) {
_activFunction = std::shared_ptr<ActivationFunction::ActivationFunction>(function.clone()); _activFunction = std::shared_ptr<ActivationFunction::ActivationFunction>(function.clone());
} }
@@ -160,6 +167,8 @@ namespace NeuralNetwork {
float _pruningLimit=0.0; float _pruningLimit=0.0;
std::size_t _epoch = 0; std::size_t _epoch = 0;
std::size_t _neurons = 0;
std::size_t _maxEpochs = 20;
std::size_t _maxHiddenUnits = 20; std::size_t _maxHiddenUnits = 20;
std::size_t _maxRandomOutputWeights = 0; std::size_t _maxRandomOutputWeights = 0;
std::size_t _numberOfCandidates; std::size_t _numberOfCandidates;