Files
NeuralNetworkLib/tests/nn-rl-xor2.cpp
2014-12-10 19:57:54 +01:00

99 lines
2.5 KiB
C++

#include "../src/NeuronNetwork/Learning/QLearning.h"

#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>
class X: public Shin::NeuronNetwork::Problem
{
public:
X(const X& a) :Problem(a) {}
X(const std::vector<float> &a):Problem() {data=a;}
};
/**
 * Parse a decimal number of the form [+|-]digits[.digits] into a float.
 *
 * Hand-rolled and locale-independent: '.' is always the decimal separator.
 * Parsing stops at the first character that does not fit that grammar, so
 * trailing text (e.g. "0.5x") is ignored and exponents ("1e-3") are not
 * supported.  A null pointer or a string without digits yields 0.
 *
 * NOTE: this is an overload for mutable char* arguments (e.g. argv entries);
 * it is distinct from the C library's atof(const char*).
 *
 * @param s NUL-terminated string to parse (may be null).
 * @return the parsed value.
 */
float atof(char *s)
{
	if (s == nullptr)
		return 0.0f;

	double sign = 1.0;
	if (*s == '-' || *s == '+')
	{
		if (*s == '-')
			sign = -1.0;
		++s;
	}

	// Accumulate in double: int accumulators would overflow (signed overflow
	// is undefined behaviour) on long digit strings such as "0.0000000001".
	double value = 0.0;
	for (; *s >= '0' && *s <= '9'; ++s)
		value = value * 10.0 + (*s - '0');

	if (*s == '.')
	{
		// Each fractional digit is weighted by a shrinking place value, so
		// no power-of-ten denominator ever has to be formed.
		double place = 0.1;
		for (++s; *s >= '0' && *s <= '9'; ++s, place /= 10.0)
			value += (*s - '0') * place;
	}

	return static_cast<float>(sign * value);
}
// Reward magnitude: getQuality() yields -AA for the wrong XOR output and +AA otherwise.
float AA = 10.0f;
/**
 * Reward for taking `action` in the XOR state `p`.
 *
 * When both inputs are exactly 0 or both exactly 1 the expected output is 0,
 * so action 1 is the wrong answer; for every other input pair the expected
 * output is 1, so action 0 is the wrong answer.
 *
 * @param p      state whose first two inputs are inspected.
 * @param action action chosen by the agent.
 * @return -AA if `action` is the wrong answer, +AA otherwise.
 */
float getQuality(X& p, int action)
{
	const bool expectZero = (p[0] == 0 && p[1] == 0) || (p[0] == 1 && p[1] == 1);
	const int wrongAction = expectZero ? 1 : 0;
	return (action == wrongAction) ? -AA : AA;
}
int main(int argc, char **argv)
{
srand(time(NULL));
Shin::NeuronNetwork::Learning::QLearning l(2,45,2);
if(argc==4 && argv[3][0]=='o')
{
std::cerr << "USING Optical Backpropagation\n";
l.opticalBackPropagation();
}
if(argc>=3)
{
std::cerr << "Setting learning coefficients to:" << atof(argv[1]) << "," << atof(argv[2]) << "\n";
l.setLearningCoeficient(atof(argv[1]),atof(argv[2]));
}
std::vector <std::pair<Shin::NeuronNetwork::Solution,Shin::NeuronNetwork::Problem>> p1x;
std::vector <X> states;
states.push_back(X(std::vector<float>({1,0})));
states.push_back(X(std::vector<float>({0,0})));
states.push_back(X(std::vector<float>({1,1})));
states.push_back(X(std::vector<float>({0,1})));
unsigned long step=0;
double quality=0;
while(step< 600000 && quality < (3.9*AA))
{
quality=0;
if(step%10000==0)
std::cerr << "STEP " << step << "\n";
for(unsigned i=0;i<states.size();i++)
{
int choice=l.getChoice(states[i]);
l.learn(states[i],choice,quality);
}
for(unsigned i=0;i<states.size();i++)
{
int choice=l.getChoice(states[i]);
quality+=getQuality(states[i],choice);
if(step%10000==0)
{
Shin::NeuronNetwork::Solution sol=l.getSolution(states[i]);
std::cerr << "\tState: [" << states[i][0] << "," << states[i][1] << "] Q-function: [" << sol[0] << "," <<sol[1] << "] Action " << choice << "\n";
}
}
step++;
}
std::cerr << step << "\n";
for(unsigned i=0;i<states.size();i++)
{
Shin::NeuronNetwork::Solution sol=l.getSolution(states[i]);
int choice=l.getChoice(states[i]);
std::cerr << "State: [" << states[i][0] << "," << states[i][1] << "] Q-function: [" << sol[0] << "," <<sol[1] << "] Action " << choice << "\n";
}
}