#include "../src/NeuronNetwork/FeedForwardQuick"
#include "../src/NeuronNetwork/Learning/Reinforcement"
#include "../src/NeuronNetwork/Learning/OpticalBackPropagation"

// NOTE(review): the two standard header names were missing from this file
// (bare "#include" directives); <iostream> and <vector> are reconstructed from
// usage, <cstdlib>/<ctime> are added for srand()/time().
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>

/**
 * Problem implementation that simply stores a fixed vector of inputs and
 * hands it back through representation().
 *
 * NOTE(review): the vector element type was missing from the damaged source;
 * float is assumed (the quality function below works in float) - confirm
 * against Shin::NeuronNetwork::Problem.
 */
class X: public Shin::NeuronNetwork::Problem
{
    public:
        X(const X& a) :q(a.q) {}
        X(const std::vector<float> &a):q(a) {}

        /// Returns a copy of the stored input values.
        std::vector<float> representation() const
        {
            return q;
        }
    protected:
        std::vector<float> q;
};

/**
 * Trains a {2,6,1} network on the four XOR input pairs with reinforcement
 * learning, three times with different propagator settings. Each run stops as
 * soon as learnSet() reports a value above the run's target quality, or after
 * the iteration limit. Progress is written to std::cerr.
 */
int main()
{
    srand(time(NULL));

    // The four input pairs. They live here (automatic storage, declared before
    // any network/learner object) instead of being allocated with `new` for
    // every test and never freed.
    X samples[] = {
        X(std::vector<float>({0,0})),
        X(std::vector<float>({1,0})),
        X(std::vector<float>({0,1})),
        X(std::vector<float>({1,1}))
    };
    const int sampleCount = 4;

    for (int test=0; test<3; test++)
    {
        Shin::NeuronNetwork::FeedForwardNetworkQuick q({2,6,1});
        Shin::NeuronNetwork::Learning::Reinforcement b(q);
        b.setPropagator(new Shin::NeuronNetwork::Learning::OpticalBackPropagation(q));
        b.getPropagator().setLearningCoeficient(0.9);
        // NOTE(review): entropy is enabled here for every test, although test 0
        // later prints "Testing without entropy". Left as-is; confirm intent.
        b.getPropagator().allowEntropy();

        double targetQuality = 1.7;
        if (test==2)
        {
            targetQuality = 1.62;
            // NOTE(review): the propagator installed above is already an
            // OpticalBackPropagation; this run differs by using a fresh
            // instance with learning coefficient 3.
            std::cerr << "Testing with OBP ...\n";
            b.setPropagator(new Shin::NeuronNetwork::Learning::OpticalBackPropagation(q));
            b.getPropagator().setLearningCoeficient(3);
        }

        // Quality of one answer: the expected output is the XOR of the two
        // inputs. For an expected 0 the score is positive while the output is
        // below 0.33; for an expected 1 it is positive while the output is
        // above 0.67. The margin is scaled by 9.
        b.setQualityFunction(
            [](const Shin::NeuronNetwork::Problem &pr,const Shin::NeuronNetwork::Solution &s)->float
            {
                std::vector<float> in = pr;
                const bool expectOne = (in[0] != 0) != (in[1] != 0);
                float margin = expectOne ? s[0]-0.67 : 0.33-s[0];
                return margin*9.0;
            });

        // Non-owning pointers to the samples, in the form passed to learnSet().
        // NOTE(review): the element type was missing from the damaged source;
        // Problem* is assumed - confirm against Reinforcement::learnSet.
        std::vector<Shin::NeuronNetwork::Problem*> p;
        for (int k=0; k<sampleCount; k++)
        {
            p.push_back(&samples[k]);
        }

        if (test==1)
        {
            std::cerr << "Testing with entropy ...\n";
            b.getPropagator().allowEntropy();
        }
        else
        {
            std::cerr << "Testing without entropy ...\n";
        }

        for (int i=0; i<500000000; i++)
        // for(int i=0;i < 5;i++)
        {
            double err = b.learnSet(p);

            // Periodically re-seed the C random generator.
            if (i%100000==0)
                srand(time(NULL));

            const bool reached = err > targetQuality;

            // Report every 200000 iterations and once more when the target is hit.
            if (i%200000==0 || reached)
            {
                std::cerr << i << " ("<< err <<").\n";
                for (int j=0; j<sampleCount; j++)
                {
                    const std::vector<float> in = samples[j].representation();
                    std::cerr << "\t" << j << ". FOR: [" << in[0] << "," << in[1]
                              << "] res: " << q.solve(samples[j])[0] << "\n";
                }
            }

            if (reached)
                break;
        }
    }
}