#include #include #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Weffc++" #include #pragma GCC diagnostic pop TEST(OpticalBackPropagation,XOR) { NeuralNetwork::FeedForward::Network n(2); NeuralNetwork::ActivationFunction::Sigmoid a(-1); n.appendLayer(2,a); n.appendLayer(1,a); n.randomizeWeights(); NeuralNetwork::Learning::OpticalBackPropagation prop(n); for(int i=0;i<10000;i++) { prop.teach({1,0},{1}); prop.teach({1,1},{0}); prop.teach({0,0},{0}); prop.teach({0,1},{1}); } { std::vector ret =n.computeOutput({1,1}); ASSERT_LT(ret[0], 0.1); } { std::vector ret =n.computeOutput({0,1}); ASSERT_GT(ret[0], 0.9); } { std::vector ret =n.computeOutput({1,0}); ASSERT_GT(ret[0], 0.9); } { std::vector ret =n.computeOutput({0,0}); ASSERT_LT(ret[0], 0.1); } } TEST(OpticalBackPropagation,AND) { NeuralNetwork::FeedForward::Network n(2); NeuralNetwork::ActivationFunction::Sigmoid a(-1); n.appendLayer(2,a); n.appendLayer(1,a); n.randomizeWeights(); NeuralNetwork::Learning::OpticalBackPropagation prop(n); for(int i=0;i<10000;i++) { prop.teach({1,1},{1}); prop.teach({0,0},{0}); prop.teach({0,1},{0}); prop.teach({1,0},{0}); } { std::vector ret =n.computeOutput({1,1}); ASSERT_GT(ret[0], 0.9); } { std::vector ret =n.computeOutput({0,1}); ASSERT_LT(ret[0], 0.1); } { std::vector ret =n.computeOutput({1,0}); ASSERT_LT(ret[0], 0.1); } { std::vector ret =n.computeOutput({0,0}); ASSERT_LT(ret[0], 0.1); } } TEST(OpticalBackPropagation,NOTAND) { NeuralNetwork::FeedForward::Network n(2); NeuralNetwork::ActivationFunction::Sigmoid a(-1); n.appendLayer(2,a); n.appendLayer(1,a); n.randomizeWeights(); NeuralNetwork::Learning::OpticalBackPropagation prop(n); for(int i=0;i<10000;i++) { prop.teach({1,1},{0}); prop.teach({0,0},{1}); prop.teach({0,1},{1}); prop.teach({1,0},{1}); } { std::vector ret =n.computeOutput({1,1}); ASSERT_LT(ret[0], 0.1); } { std::vector ret =n.computeOutput({0,1}); ASSERT_GT(ret[0], 0.9); } { std::vector ret =n.computeOutput({1,0}); ASSERT_GT(ret[0], 0.9); } { std::vector ret =n.computeOutput({0,0}); ASSERT_GT(ret[0], 0.9); } }