52 lines
1.2 KiB
C++
52 lines
1.2 KiB
C++
#pragma once
|
|
|
|
#include <NeuralNetwork/FeedForward/Network.h>

#include "CorrectionFunction/Linear.h"

#include <cstddef>
#include <memory>
#include <utility>
#include <vector>
|
|
|
|
namespace NeuralNetwork {
|
|
namespace Learning {
|
|
class BatchPropagation {
|
|
public:
|
|
BatchPropagation(FeedForward::Network &ffn, std::shared_ptr<CorrectionFunction::CorrectionFunction> correction) : _network(ffn), _correctionFunction(correction) {
|
|
|
|
}
|
|
|
|
virtual ~BatchPropagation() {
|
|
|
|
}
|
|
|
|
void teach(const std::vector<float> &input, const std::vector<float> &output);
|
|
|
|
void finishTeaching();
|
|
|
|
std::size_t getBatchSize() const {
|
|
return _batchSize;
|
|
}
|
|
|
|
void setBatchSize(std::size_t size) {
|
|
_batchSize = size;
|
|
}
|
|
protected:
|
|
virtual void updateWeightsAndEndBatch() = 0;
|
|
virtual void resize();
|
|
|
|
FeedForward::Network &_network;
|
|
std::shared_ptr<CorrectionFunction::CorrectionFunction> _correctionFunction;
|
|
|
|
std::size_t _batchSize = 1;
|
|
std::size_t _currentBatchSize = 0;
|
|
|
|
std::vector<std::vector<float>> _slopes = {};
|
|
std::vector<std::vector<std::vector<float>>> _gradients = {};
|
|
|
|
bool init = false;
|
|
private:
|
|
void computeSlopes(const std::vector<float> &expectation);
|
|
void computeDeltas(const std::vector<float> &input);
|
|
};
|
|
}
|
|
} |