// File listing metadata: 35 lines, 1.0 KiB, C++
#include <NeuralNetwork/Recurrent/Network.h>
|
|
|
|
#pragma GCC diagnostic push
|
|
#pragma GCC diagnostic ignored "-Weffc++"
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
#pragma GCC diagnostic pop
|
|
|
|
// Exercises a small recurrent network end to end:
//  1. builds Network(2,1,1) and overrides four connection weights,
//  2. feeds the same input {1, 0.7} ten times and compares the first output
//     of every step against pre-computed reference values,
//  3. serializes the network with stringify(), rebuilds it through
//     Network::Factory::deserialize() and checks that the original and the
//     rebuilt network produce matching outputs (within 0.01) for ten more steps.
TEST(Recurrent, Sample) {
	// NOTE(review): constructor arguments are presumably layer/neuron counts —
	// confirm against Network.h.
	NeuralNetwork::Recurrent::Network a(2,1,1);

	a.getNeurons()[4]->weight(1)=0.05;
	a.getNeurons()[4]->weight(2)=0.05;
	a.getNeurons()[4]->weight(3)=0.7;
	a.getNeurons()[3]->weight(4)=0.1;

	// Expected first output after each successive computeOutput() call.
	const std::vector<float> solutions({0.5,0.5732923,0.6077882,0.6103067,0.6113217,0.6113918,0.61142,0.6114219,0.6114227,0.6114227});

	for(const float expected : solutions) {
		const float res = a.computeOutput({1,0.7})[0];
		ASSERT_FLOAT_EQ(res, expected);
	}

	const std::string str = a.stringify();

	// Deserialize and check it! The factory's return value is kept in its
	// owning handle (it exposes release(), so presumably a std::unique_ptr)
	// instead of being released into a raw pointer, so the network is freed
	// even when one of the assertions below returns early.
	auto deserialized = NeuralNetwork::Recurrent::Network::Factory::deserialize(str);
	ASSERT_TRUE(deserialized != nullptr);

	for(size_t i=0;i<solutions.size();i++) {
		const float res = a.computeOutput({1,0.7})[0];
		const float resDeserialized = deserialized->computeOutput({1,0.7})[0];
		// gtest assertion instead of assert(): still active in NDEBUG builds
		// and reported as a regular test failure rather than an abort.
		ASSERT_NEAR(resDeserialized, res, 0.01);
	}
}