#pragma once #include #include "CorrectionFunction/Linear.h" #include #include namespace NeuralNetwork { namespace Learning { class BatchPropagation { public: BatchPropagation(FeedForward::Network &ffn, std::shared_ptr correction) : _network(ffn), _correctionFunction(correction) { } virtual ~BatchPropagation() { } void teach(const std::vector &input, const std::vector &output); void finishTeaching(); std::size_t getBatchSize() const { return _batchSize; } void setBatchSize(std::size_t size) { _batchSize = size; } protected: virtual void updateWeightsAndEndBatch() = 0; virtual void resize(); FeedForward::Network &_network; std::shared_ptr _correctionFunction; std::size_t _batchSize = 1; std::size_t _currentBatchSize = 0; std::vector> _slopes = {}; std::vector>> _gradients = {}; bool init = false; private: void computeSlopes(const std::vector &expectation); void computeDeltas(const std::vector &input); }; } }