add RectifiedUnit and LeakyRectifiedUnit

This commit is contained in:
2016-11-01 21:59:22 +01:00
parent 3006a1b237
commit 9f476b4233
2 changed files with 78 additions and 0 deletions

View File

@@ -0,0 +1,39 @@
#pragma once
#include "./ActivationFunction.h"
#include <cassert>
namespace NeuralNetwork {
namespace ActivationFunction {
class LeakyRectifiedLinear: public ActivationFunction {
public:
LeakyRectifiedLinear(const float &lambdaP=0.04): lambda(lambdaP) {}
inline virtual float derivatedOutput(const float &inp,const float &) const override {
return inp > 0.0f ? lambda : 0.01f*lambda;
}
inline virtual float operator()(const float &x) const override {
return x > 0.0? x : 0.001f*x;
};
virtual ActivationFunction* clone() const override {
return new LeakyRectifiedLinear(lambda);
}
virtual SimpleJSON::Type::Object serialize() const override {
return {{"class", "NeuralNetwork::ActivationFunction::LeakyRectifiedLinear"}, {"lambda", lambda}};
}
static std::unique_ptr<LeakyRectifiedLinear> deserialize(const SimpleJSON::Type::Object &obj) {
return std::unique_ptr<LeakyRectifiedLinear>(new LeakyRectifiedLinear(obj["lambda"].as<double>()));
}
protected:
float lambda;
NEURAL_NETWORK_REGISTER_ACTIVATION_FUNCTION(NeuralNetwork::ActivationFunction::LeakyRectifiedLinear, LeakyRectifiedLinear::deserialize)
};
}
}

View File

@@ -0,0 +1,39 @@
#pragma once
#include "./ActivationFunction.h"
#include <cassert>
namespace NeuralNetwork {
namespace ActivationFunction {
class RectifiedLinear: public ActivationFunction {
public:
RectifiedLinear(const float &lambdaP=0.1): lambda(lambdaP) {}
inline virtual float derivatedOutput(const float &inp,const float &) const override {
return inp > 0.0f ? lambda : 0.0f;
}
inline virtual float operator()(const float &x) const override {
return std::max(0.0f,x);
};
virtual ActivationFunction* clone() const override {
return new RectifiedLinear(lambda);
}
virtual SimpleJSON::Type::Object serialize() const override {
return {{"class", "NeuralNetwork::ActivationFunction::RectifiedLinear"}, {"lambda", lambda}};
}
static std::unique_ptr<RectifiedLinear> deserialize(const SimpleJSON::Type::Object &obj) {
return std::unique_ptr<RectifiedLinear>(new RectifiedLinear(obj["lambda"].as<double>()));
}
protected:
float lambda;
NEURAL_NETWORK_REGISTER_ACTIVATION_FUNCTION(NeuralNetwork::ActivationFunction::RectifiedLinear, RectifiedLinear::deserialize)
};
}
}