39 lines
983 B
C++
39 lines
983 B
C++
#pragma once
|
|
|
|
#include <mmintrin.h>
|
|
#include <xmmintrin.h>
|
|
#include <emmintrin.h>
|
|
#include <xmmintrin.h>
|
|
#include <pmmintrin.h>
|
|
#include <immintrin.h>
|
|
|
|
#include <cassert>
|
|
|
|
#include "./StreamingBasisFunction.h"
|
|
|
|
#include "../../sse_mathfun.h"
|
|
|
|
namespace NeuralNetwork {
|
|
namespace BasisFunction {
|
|
|
|
class Linear: public StreamingBasisFunction {
|
|
public:
|
|
Linear() {}
|
|
|
|
virtual float operator()(const std::vector<float>& weights, const std::vector<float>& input) const override;
|
|
|
|
virtual std::unique_ptr<BasisFunction> clone() const override {
|
|
return std::unique_ptr<BasisFunction>(new Linear());
|
|
}
|
|
|
|
virtual SimpleJSON::Type::Object serialize() const override {
|
|
return {{"class", "NeuralNetwork::BasisFunction::Linear"}};
|
|
}
|
|
|
|
static std::unique_ptr<Linear> deserialize(const SimpleJSON::Type::Object &) {
|
|
return std::unique_ptr<Linear>(new Linear());
|
|
}
|
|
NEURAL_NETWORK_REGISTER_BASIS_FUNCTION(NeuralNetwork::BasisFunction::Linear, deserialize)
|
|
};
|
|
}
|
|
} |