40 CNN_USE_LAYER_MEMBERS;
74 :
Base({vector_type::data}),
89 return in2out_[0].size();
92 void forward_propagation(serial_size_t index,
93 const std::vector<vec_t*>&
in_data,
94 std::vector<vec_t*>&
out_data)
override {
98 std::vector<serial_size_t>&
max_idx = max_unpooling_layer_worker_storage_[index].in2outmax_;
101 for (int i = r.begin(); i < r.end(); i++) {
102 const auto& in_index = out2in_[i];
103 a[i] = (max_idx[in_index] == i) ? in[in_index] : float_t(0);
107 this->forward_activation(*out_data[0], *out_data[1]);
110 void back_propagation(serial_size_t index,
111 const std::vector<vec_t*>& in_data,
112 const std::vector<vec_t*>& out_data,
113 std::vector<vec_t*>& out_grad,
114 std::vector<vec_t*>& in_grad)
override {
115 vec_t& prev_delta = *in_grad[0];
116 vec_t& curr_delta = *out_grad[1];
117 std::vector<serial_size_t>& max_idx = max_unpooling_layer_worker_storage_[index].in2outmax_;
119 CNN_UNREFERENCED_PARAMETER(in_data);
121 this->backward_activation(*out_grad[0], *out_data[0], curr_delta);
123 for_(parallelize_, 0, in2out_.size(), [&](
const blocked_range& r) {
124 for (int i = r.begin(); i != r.end(); i++) {
125 serial_size_t outi = out2in_[i];
126 prev_delta[i] = (max_idx[outi] == i) ? curr_delta[outi] : float_t(0);
131 std::vector<index3d<serial_size_t>>
in_shape()
const override {
return {in_}; }
132 std::vector<index3d<serial_size_t>>
out_shape()
const override {
return {out_, out_}; }
133 std::string
layer_type()
const override {
return "max-unpool"; }
134 size_t unpool_size()
const {
return unpool_size_;}
136 virtual void set_worker_count(serial_size_t worker_count)
override {
137 Base::set_worker_count(worker_count);
138 max_unpooling_layer_worker_storage_.resize(worker_count);
139 for (max_unpooling_layer_worker_specific_storage& mws : max_unpooling_layer_worker_storage_) {
140 mws.in2outmax_.resize(out_.size());
144 template <
class Archive>
145 static void load_and_construct(Archive & ar, cereal::construct<max_unpooling_layer> & construct) {
147 serial_size_t stride, unpool_size;
149 ar(cereal::make_nvp(
"in_size", in), cereal::make_nvp(
"unpool_size", unpool_size), cereal::make_nvp(
"stride", stride));
150 construct(in, unpool_size, stride);
153 template <
class Archive>
154 void serialize(Archive & ar) {
155 layer::serialize_prolog(ar);
156 ar(cereal::make_nvp(
"in_size", in_), cereal::make_nvp(
"unpool_size", unpool_size_), cereal::make_nvp(
"stride", stride_));
// Unpooling window extent: connect_kernel() bounds both its dx and dy loops by
// this value (init_connection() passes it in as `unpooling_size`).
160 serial_size_t unpool_size_;
// Spacing between consecutive input positions in the output grid
// (connect_kernel() maps input x/y to output x/y via `inx * stride_ + dx`).
161 serial_size_t stride_;
// out2in_[o] = index of the input element that output element o is connected to
// (filled by connect_kernel(); sized to out_.size() in init_connection()).
162 std::vector<serial_size_t> out2in_;
// in2out_[i] = all output element indices connected to input element i
// (filled by connect_kernel(); sized to in_.size() in init_connection()).
163 std::vector<std::vector<serial_size_t> > in2out_;
165 struct max_unpooling_layer_worker_specific_storage {
166 std::vector<serial_size_t> in2outmax_;
// One scratch entry per worker. The vector is resized in set_worker_count(), and
// entries are selected by the worker `index` in forward_propagation() and
// back_propagation().
// NOTE(review): each entry's in2outmax_ is resized to out_.size() in
// set_worker_count() but to in_.size() in init_connection(); the final size
// depends on call order. Confirm which one is intended.
169 std::vector<max_unpooling_layer_worker_specific_storage> max_unpooling_layer_worker_storage_;
// Shape of the layer's single input tensor (reported by in_shape()).
171 index3d<serial_size_t> in_;
// Shape of each of the layer's output tensors (reported twice by out_shape()).
172 index3d<serial_size_t> out_;
174 static serial_size_t unpool_out_dim(serial_size_t in_size, serial_size_t unpooling_size, serial_size_t stride) {
175 return (
int) (float_t)in_size * stride + unpooling_size - 1;
178 void connect_kernel(serial_size_t unpooling_size, serial_size_t inx, serial_size_t iny, serial_size_t c)
180 serial_size_t dxmax =
static_cast<serial_size_t
>(std::min(unpooling_size, inx * stride_ - out_.width_));
181 serial_size_t dymax =
static_cast<serial_size_t
>(std::min(unpooling_size, iny * stride_ - out_.height_));
183 for (serial_size_t dy = 0; dy < dymax; dy++) {
184 for (serial_size_t dx = 0; dx < dxmax; dx++) {
185 serial_size_t out_index = out_.get_index(
static_cast<serial_size_t
>(inx * stride_ + dx),
186 static_cast<serial_size_t
>(iny * stride_ + dy), c);
187 serial_size_t in_index = in_.get_index(inx, iny, c);
189 if (in_index >= in2out_.size())
190 throw nn_error(
"index overflow");
191 if (out_index >= out2in_.size())
192 throw nn_error(
"index overflow");
193 out2in_[out_index] = in_index;
194 in2out_[in_index].push_back(out_index);
199 void init_connection()
201 in2out_.resize(in_.size());
202 out2in_.resize(out_.size());
204 for (max_unpooling_layer_worker_specific_storage& mws : max_unpooling_layer_worker_storage_) {
205 mws.in2outmax_.resize(in_.size());
208 for (serial_size_t c = 0; c < in_.depth_; ++c)
209 for (serial_size_t y = 0; y < in_.height_; ++y)
210 for (serial_size_t x = 0; x < in_.width_; ++x)
211 connect_kernel(
static_cast<serial_size_t
>(unpool_size_),