28#include "tiny_dnn/layers/layer.h"
30#include "tiny_dnn/core/kernels/fully_connected_op.h"
31#include "tiny_dnn/core/kernels/fully_connected_grad_op.h"
38template<
typename Activation>
42 CNN_USE_LAYER_MEMBERS;
52 backend_t backend_type = core::default_engine())
53 :
Base(std_input_order(has_bias)) {
55 init_backend(backend_type);
56 Base::set_backend_type(backend_type);
61 : Base(std::
move(other))
62 , params_(std::
move(other.params_))
63 , kernel_fwd_(std::
move(other.kernel_fwd_))
64 , kernel_back_(std::
move(other.kernel_back_)) {
65 init_backend(std::move(other.engine()));
69 return params_.in_size_;
73 return params_.out_size_;
76 std::vector<index3d<serial_size_t>>
in_shape()
const override {
77 if (params_.has_bias_) {
80 params_.out_size_, 1),
85 params_.out_size_, 1) };
89 std::vector<index3d<serial_size_t>>
out_shape()
const override {
95 std::vector<tensor_t*>&
out_data)
override {
98 ctx.setParallelize(layer::parallelize());
99 ctx.setEngine(layer::engine());
102 kernel_fwd_->compute(
ctx);
109 const std::vector<tensor_t*>&
out_data,
111 std::vector<tensor_t*>&
in_grad)
override {
118 ctx.setParallelize(layer::parallelize());
119 ctx.setEngine(layer::engine());
122 kernel_back_->compute(
ctx);
125 std::string
layer_type()
const override {
return "fully-connected"; }
127 template <
class Archive>
128 static void load_and_construct(
Archive &
ar, cereal::construct<fully_connected_layer> & construct) {
132 ar(cereal::make_nvp(
"in_size",
in_dim),
133 cereal::make_nvp(
"out_size",
out_dim),
134 cereal::make_nvp(
"has_bias", has_bias));
138 template <
class Archive>
139 void serialize(Archive & ar) {
140 layer::serialize_prolog(ar);
141 ar(cereal::make_nvp(
"in_size", params_.in_size_),
142 cereal::make_nvp(
"out_size", params_.out_size_),
143 cereal::make_nvp(
"has_bias", params_.has_bias_));
148 void set_params(
const serial_size_t
in_size,
153 params_.has_bias_ = has_bias;
156 void init_backend(backend_t backend_type) {
157 core::OpKernelConstruction ctx =
158 core::OpKernelConstruction(layer::device(), ¶ms_);
160 if (backend_type == backend_t::internal ||
161 backend_type == backend_t::avx||
162 backend_type == backend_t::nnpack
165 kernel_fwd_.reset(
new FullyConnectedOp(ctx));
166 kernel_back_.reset(
new FullyConnectedGradOp(ctx));
171 throw nn_error(
"Not supported engine: " + to_string(backend_type));
177 fully_params params_;
180 std::shared_ptr<core::OpKernel> kernel_fwd_;
181 std::shared_ptr<core::OpKernel> kernel_back_;
Single-input, single-output network with an activation function.
Definition feedforward_layer.h:37
Computes a fully-connected (matrix-multiplication) operation.
Definition fully_connected_layer.h:39
std::string layer_type() const override
Name of the layer; should be unique for each concrete class.
Definition fully_connected_layer.h:125
void back_propagation(const std::vector< tensor_t * > &in_data, const std::vector< tensor_t * > &out_data, std::vector< tensor_t * > &out_grad, std::vector< tensor_t * > &in_grad) override
Returns the delta of the previous layer, $\delta = \frac{dE}{da}$, where $a = wx$ in a fully-connected layer.
Definition fully_connected_layer.h:108
fully_connected_layer(serial_size_t in_dim, serial_size_t out_dim, bool has_bias=true, backend_t backend_type=core::default_engine())
Definition fully_connected_layer.h:49
std::vector< index3d< serial_size_t > > out_shape() const override
Array of output shapes (width × height × depth).
Definition fully_connected_layer.h:89
void forward_propagation(const std::vector< tensor_t * > &in_data, std::vector< tensor_t * > &out_data) override
Definition fully_connected_layer.h:94
serial_size_t fan_out_size() const override
Number of outgoing connections for each input unit; used only by weight/bias initialization methods.
Definition fully_connected_layer.h:72
serial_size_t fan_in_size() const override
Number of incoming connections for each output unit; used only by weight/bias initialization methods.
Definition fully_connected_layer.h:68
std::vector< index3d< serial_size_t > > in_shape() const override
Array of input shapes (width × height × depth).
Definition fully_connected_layer.h:76
Simple image utility class.
Definition image.h:94
serial_size_t out_size() const
!
Definition layer.h:181
serial_size_t in_size() const
!
Definition layer.h:176