tiny_dnn 1.0.0
A header only, dependency-free deep learning framework in C++11
Loading...
Searching...
No Matches
weight_init.h
1/*
2 Copyright (c) 2015, Taiga Nomi
3 All rights reserved.
4
5 Redistribution and use in source and binary forms, with or without
6 modification, are permitted provided that the following conditions are met:
7 * Redistributions of source code must retain the above copyright
8 notice, this list of conditions and the following disclaimer.
9 * Redistributions in binary form must reproduce the above copyright
10 notice, this list of conditions and the following disclaimer in the
11 documentation and/or other materials provided with the distribution.
12 * Neither the name of the <organization> nor the
13 names of its contributors may be used to endorse or promote products
14 derived from this software without specific prior written permission.
15
16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
17 EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY
20 DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23 ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26*/
27#pragma once
28#include "tiny_dnn/util/util.h"
29
30namespace tiny_dnn {
31namespace weight_init {
32
33class function {
34public:
35 virtual void fill(vec_t *weight, serial_size_t fan_in, serial_size_t fan_out) = 0;
36};
37
38class scalable : public function {
39public:
40 scalable(float_t value) : scale_(value) {}
41
42 void scale(float_t value) {
43 scale_ = value;
44 }
45protected:
46 float_t scale_;
47};
48
56class xavier : public scalable {
57public:
58 xavier() : scalable(float_t(6)) {}
59 explicit xavier(float_t value) : scalable(value) {}
60
61 void fill(vec_t *weight, serial_size_t fan_in, serial_size_t fan_out) override {
62 const float_t weight_base = std::sqrt(scale_ / (fan_in + fan_out));
63
64 uniform_rand(weight->begin(), weight->end(), -weight_base, weight_base);
65 }
66};
67
75class lecun : public scalable {
76public:
77 lecun() : scalable(float_t(1)) {}
78 explicit lecun(float_t value) : scalable(value) {}
79
80 void fill(vec_t *weight, serial_size_t fan_in, serial_size_t fan_out) override {
81 CNN_UNREFERENCED_PARAMETER(fan_out);
82
83 const float_t weight_base = scale_ / std::sqrt(float_t(fan_in));
84
85 uniform_rand(weight->begin(), weight->end(), -weight_base, weight_base);
86 }
87};
88
89class gaussian : public scalable {
90public:
91 gaussian() : scalable(float_t(1)) {}
92 explicit gaussian(float_t sigma) : scalable(sigma) {}
93
94 void fill(vec_t *weight, serial_size_t fan_in, serial_size_t fan_out) override {
95 CNN_UNREFERENCED_PARAMETER(fan_in);
96 CNN_UNREFERENCED_PARAMETER(fan_out);
97
98 gaussian_rand(weight->begin(), weight->end(), float_t(0), scale_);
99 }
100};
101
102class constant : public scalable {
103public:
104 constant() : scalable(float_t(0)) {}
105 explicit constant(float_t value) : scalable(value) {}
106
107 void fill(vec_t *weight, serial_size_t fan_in, serial_size_t fan_out) override {
108 CNN_UNREFERENCED_PARAMETER(fan_in);
109 CNN_UNREFERENCED_PARAMETER(fan_out);
110
111 std::fill(weight->begin(), weight->end(), scale_);
112 }
113};
114
115class he : public scalable {
116public:
117 he() : scalable(float_t(2)) {}
118 explicit he(float_t value) : scalable(value) {}
119
120 void fill(vec_t *weight, serial_size_t fan_in, serial_size_t fan_out) override {
121 CNN_UNREFERENCED_PARAMETER(fan_out);
122
123 const float_t sigma = std::sqrt(scale_ /fan_in);
124
125 gaussian_rand(weight->begin(), weight->end(), float_t(0), sigma);
126 }
127};
128
129} // namespace weight_init
130} // namespace tiny_dnn
Simple image utility class.
Definition image.h:94
Definition weight_init.h:102
Definition weight_init.h:33
Definition weight_init.h:89
Definition weight_init.h:115
Use fan-in (the number of input weights for each neuron) for scaling.
Definition weight_init.h:75
Definition weight_init.h:38
Use fan-in and fan-out for scaling.
Definition weight_init.h:56