Files
NeuralNetworkLib/include/NeuralNetwork/Learning/BatchPropagation.h
2016-10-31 15:03:27 +01:00

52 lines
1.2 KiB
C++

#pragma once
#include <memory>
#include <utility>
#include <vector>

#include <NeuralNetwork/FeedForward/Network.h>
#include "CorrectionFunction/Linear.h"
namespace NeuralNetwork {
namespace Learning {
/**
 * Base class for batch-style training of a FeedForward::Network.
 *
 * Subclasses must implement updateWeightsAndEndBatch(). The other non-inline
 * members (teach, finishTeaching, resize, computeSlopes, computeDeltas) are
 * defined outside this header.
 *
 * Lifetime: only a reference to the network is stored, so the network must
 * outlive this object. The correction function is held by shared ownership.
 */
class BatchPropagation {
	public:
		/**
		 * @param ffn        Network to train; stored by reference (not copied, not owned).
		 * @param correction Correction function; taken by value and moved into the
		 *                   member so the shared_ptr refcount is bumped only once.
		 *                   NOTE(review): not checked for nullptr here.
		 */
		BatchPropagation(FeedForward::Network &ffn, std::shared_ptr<CorrectionFunction::CorrectionFunction> correction)
			: _network(ffn), _correctionFunction(std::move(correction)) {
		}

		// Virtual because this class is polymorphic (see pure virtual updateWeightsAndEndBatch).
		virtual ~BatchPropagation() = default;

		/**
		 * Presents one training sample: an input vector and its expected output.
		 * Defined out of line; presumably accumulates gradients and calls
		 * updateWeightsAndEndBatch() once getBatchSize() samples have been seen — TODO confirm.
		 */
		void teach(const std::vector<float> &input, const std::vector<float> &output);

		/**
		 * Defined out of line; presumably applies any partially filled batch — TODO confirm.
		 */
		void finishTeaching();

		/// Number of samples per batch (defaults to 1).
		std::size_t getBatchSize() const {
			return _batchSize;
		}

		/**
		 * Sets the number of samples per batch. No validation is performed here.
		 * NOTE(review): a size of 0, or shrinking below the number of samples already
		 * accumulated in the current batch, may behave unexpectedly depending on how
		 * teach() compares _currentBatchSize against _batchSize — confirm.
		 */
		void setBatchSize(std::size_t size) {
			_batchSize = size;
		}

	protected:
		/// Must be implemented by subclasses; presumably applies the accumulated update and ends the batch.
		virtual void updateWeightsAndEndBatch() = 0;

		/// Defined out of line; presumably sizes the internal buffers to the network — TODO confirm.
		virtual void resize();

		/// Network being trained (not owned; must outlive this object).
		FeedForward::Network &_network;
		/// Correction function, shared with whoever else holds it.
		std::shared_ptr<CorrectionFunction::CorrectionFunction> _correctionFunction;
		/// Samples per batch.
		std::size_t _batchSize = 1;
		/// Presumably the number of samples accumulated in the current batch.
		std::size_t _currentBatchSize = 0;
		/// Presumably indexed [layer][neuron] — TODO confirm.
		std::vector<std::vector<float>> _slopes = {};
		/// Presumably indexed [layer][neuron][weight] — TODO confirm.
		std::vector<std::vector<std::vector<float>>> _gradients = {};
		/// Presumably set once the buffers have been sized by resize() — TODO confirm.
		bool init = false;

	private:
		/// Defined out of line; takes the expected output of the current sample.
		void computeSlopes(const std::vector<float> &expectation);
		/// Defined out of line; takes the input of the current sample.
		void computeDeltas(const std::vector<float> &input);
};
}
}