tiny_dnn 1.0.0
A header-only, dependency-free deep learning framework in C++11
conv2d_grad_op_avx.h
/*
    Copyright (c) 2016, Taiga Nomi, Edgar Riba
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:
    * Redistributions of source code must retain the above copyright
    notice, this list of conditions and the following disclaimer.
    * Redistributions in binary form must reproduce the above copyright
    notice, this list of conditions and the following disclaimer in the
    documentation and/or other materials provided with the distribution.
    * Neither the name of the <organization> nor the
    names of its contributors may be used to endorse or promote products
    derived from this software without specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
    EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
    WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
    DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY
    DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
    LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
    ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#pragma once

#include <numeric>  // std::accumulate (bias gradient)
#include <vector>

#include "tiny_dnn/core/params/conv_params.h"
#include "tiny_dnn/core/kernels/conv2d_op_internal.h"

#ifdef CNN_USE_AVX
#include "tiny_dnn/core/kernels/avx_kernel_common.h"
#endif

namespace tiny_dnn {
namespace kernels {

#ifdef CNN_USE_AVX

// float version
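// Single-sample backward kernel for 5x5 convolutions. Given the forward
// activations (prev_out), the weights (W) and the current layer's delta
// (curr_delta), it accumulates the weight gradient (dW), the bias gradient
// (db) and the delta propagated to the previous layer (*prev_delta).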
template <typename Allocator>
void avx_conv2d_5x5_back_kernel_one(const core::conv_params& params,
                                    const std::vector<float, Allocator>& prev_out,
                                    const std::vector<float, Allocator>& W,
                                    std::vector<float, Allocator>& dW,
                                    std::vector<float, Allocator>& db,
                                    std::vector<float, Allocator>& curr_delta,
                                    std::vector<float, Allocator>* prev_delta) {
    auto& in = params.in;
    auto& out = params.out;
    auto& in_padded = params.in_padded;
    auto& tbl = params.tbl;
    auto w_stride = params.w_stride;
    const size_t in_padded_area = in_padded.area();
    float* pdelta_dst_org = &(*prev_delta)[0];
    const size_t h_stride2 = params.h_stride * in_padded.width_;
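    // imask / mask select the low five lanes of a 256-bit register, i.e. one
    // 5-float kernel row (the upper three lanes are zeroed / left untouched)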
    static const __m256i imask = _mm256_setr_epi32(-1, -1, -1, -1, -1, 0, 0, 0);
    static const __m256 mask = _mm256_castsi256_ps(_mm256_setr_epi32(-1, -1, -1, -1, -1, 0, 0, 0));
    // propagate delta to previous layer
    if (w_stride == 1 && out.width_ >= 4) {
        const serial_size_t nblocks = out.width_ / 4;
        for (serial_size_t inc = 0; inc < in.depth_; ++inc, pdelta_dst_org += in_padded_area) {
            for (serial_size_t outc = 0; outc < out.depth_; outc++) {
                if (!tbl.is_connected(outc, inc)) continue;
                const float* pw = &W[25 * (in.depth_ * outc + inc)];
                const float* pdelta_src = &curr_delta[out.get_index(0, 0, outc)];
                float* pdelta_dst = pdelta_dst_org;
                __m256 w0a = _mm256_and_ps(_mm256_loadu_ps(pw + 0), mask);
                __m256 w1a = _mm256_and_ps(_mm256_loadu_ps(pw + 5), mask);
                __m256 w2a = _mm256_and_ps(_mm256_loadu_ps(pw + 10), mask);
                __m256 w3a = _mm256_and_ps(_mm256_loadu_ps(pw + 15), mask);
                __m256 w4a = _mm256_and_ps(_mm256_loadu_ps(pw + 20), mask);
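                // w*b/c/d are copies of each kernel row shifted left by 1, 2
                // and 3 floats, so that four adjacent output deltas can be
                // accumulated into one 8-wide row of prev_delta per iteration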
                __m256 w0b = leftShift<4>(w0a);
                __m256 w1b = leftShift<4>(w1a);
                __m256 w2b = leftShift<4>(w2a);
                __m256 w3b = leftShift<4>(w3a);
                __m256 w4b = leftShift<4>(w4a);
                __m256 w0c = leftShift<8>(w0a);
                __m256 w1c = leftShift<8>(w1a);
                __m256 w2c = leftShift<8>(w2a);
                __m256 w3c = leftShift<8>(w3a);
                __m256 w4c = leftShift<8>(w4a);
                __m256 w0d = leftShift<12>(w0a);
                __m256 w1d = leftShift<12>(w1a);
                __m256 w2d = leftShift<12>(w2a);
                __m256 w3d = leftShift<12>(w3a);
                __m256 w4d = leftShift<12>(w4a);
                for (serial_size_t y = 0; y < out.height_; y++) {
                    const float* pdelta_src2 = pdelta_src;
                    float* delta_dst0 = pdelta_dst;
                    float* delta_dst1 = &pdelta_dst[in_padded.width_ * 1];
                    float* delta_dst2 = &pdelta_dst[in_padded.width_ * 2];
                    float* delta_dst3 = &pdelta_dst[in_padded.width_ * 3];
                    float* delta_dst4 = &pdelta_dst[in_padded.width_ * 4];
                    for (serial_size_t n = 0; n < nblocks; ++n) {
                        __m256 delta_src = _mm256_broadcast_ps((const __m128*)pdelta_src2);
                        __m256 dst0 = _mm256_loadu_ps(delta_dst0 + 4 * n);
                        __m256 dst1 = _mm256_loadu_ps(delta_dst1 + 4 * n);
                        __m256 dst2 = _mm256_loadu_ps(delta_dst2 + 4 * n);
                        __m256 dst3 = _mm256_loadu_ps(delta_dst3 + 4 * n);
                        __m256 dst4 = _mm256_loadu_ps(delta_dst4 + 4 * n);
                        __m256 delta_src0 = _mm256_permute_ps(delta_src, _MM_SHUFFLE(0, 0, 0, 0));
                        __m256 delta_src1 = _mm256_permute_ps(delta_src, _MM_SHUFFLE(1, 1, 1, 1));
                        __m256 delta_src2 = _mm256_permute_ps(delta_src, _MM_SHUFFLE(2, 2, 2, 2));
                        __m256 delta_src3 = _mm256_permute_ps(delta_src, _MM_SHUFFLE(3, 3, 3, 3));
                        dst0 = madd256_ps(w0a, delta_src0, dst0);
                        dst1 = madd256_ps(w1a, delta_src0, dst1);
                        dst2 = madd256_ps(w2a, delta_src0, dst2);
                        dst3 = madd256_ps(w3a, delta_src0, dst3);
                        dst4 = madd256_ps(w4a, delta_src0, dst4);
                        dst0 = madd256_ps(w0b, delta_src1, dst0);
                        dst1 = madd256_ps(w1b, delta_src1, dst1);
                        dst2 = madd256_ps(w2b, delta_src1, dst2);
                        dst3 = madd256_ps(w3b, delta_src1, dst3);
                        dst4 = madd256_ps(w4b, delta_src1, dst4);
                        dst0 = madd256_ps(w0c, delta_src2, dst0);
                        dst1 = madd256_ps(w1c, delta_src2, dst1);
                        dst2 = madd256_ps(w2c, delta_src2, dst2);
                        dst3 = madd256_ps(w3c, delta_src2, dst3);
                        dst4 = madd256_ps(w4c, delta_src2, dst4);
                        dst0 = madd256_ps(w0d, delta_src3, dst0);
                        _mm256_storeu_ps(delta_dst0 + 4 * n, dst0);
                        dst1 = madd256_ps(w1d, delta_src3, dst1);
                        _mm256_storeu_ps(delta_dst1 + 4 * n, dst1);
                        dst2 = madd256_ps(w2d, delta_src3, dst2);
                        _mm256_storeu_ps(delta_dst2 + 4 * n, dst2);
                        dst3 = madd256_ps(w3d, delta_src3, dst3);
                        _mm256_storeu_ps(delta_dst3 + 4 * n, dst3);
                        dst4 = madd256_ps(w4d, delta_src3, dst4);
                        _mm256_storeu_ps(delta_dst4 + 4 * n, dst4);
                        pdelta_src2 += 4;
                    }
                    for (serial_size_t x = nblocks * 4; x < out.width_; x++) {
                        __m256 delta_src = _mm256_broadcast_ss(pdelta_src + x);
                        __m256 dst0 = _mm256_loadu_ps(delta_dst0 + x);
                        __m256 dst1 = _mm256_loadu_ps(delta_dst1 + x);
                        __m256 dst2 = _mm256_loadu_ps(delta_dst2 + x);
                        __m256 dst3 = _mm256_loadu_ps(delta_dst3 + x);
                        __m256 dst4 = _mm256_loadu_ps(delta_dst4 + x);
                        dst0 = madd256_ps(w0a, delta_src, dst0);
                        dst1 = madd256_ps(w1a, delta_src, dst1);
                        dst2 = madd256_ps(w2a, delta_src, dst2);
                        dst3 = madd256_ps(w3a, delta_src, dst3);
                        dst4 = madd256_ps(w4a, delta_src, dst4);
                        _mm256_storeu_ps(delta_dst0 + x, dst0);
                        _mm256_storeu_ps(delta_dst1 + x, dst1);
                        _mm256_storeu_ps(delta_dst2 + x, dst2);
                        _mm256_storeu_ps(delta_dst3 + x, dst3);
                        _mm256_storeu_ps(delta_dst4 + x, dst4);
                    }
                    pdelta_src += out.width_;
                    pdelta_dst += h_stride2;
                }
            }
        }
    } else if (out.height_ == 1 && out.width_ == 1) {
        for (serial_size_t inc = 0; inc < in.depth_; ++inc, pdelta_dst_org += in_padded_area) {
            float* delta_dst0 = pdelta_dst_org;
            float* delta_dst1 = &pdelta_dst_org[in_padded.width_ * 1];
            float* delta_dst2 = &pdelta_dst_org[in_padded.width_ * 2];
            float* delta_dst3 = &pdelta_dst_org[in_padded.width_ * 3];
            float* delta_dst4 = &pdelta_dst_org[in_padded.width_ * 4];
            __m256 dst0 = _mm256_loadu_ps(delta_dst0);
            __m256 dst1 = _mm256_loadu_ps(delta_dst1);
            __m256 dst2 = _mm256_loadu_ps(delta_dst2);
            __m256 dst3 = _mm256_loadu_ps(delta_dst3);
            __m256 dst4 = _mm256_maskload_ps(delta_dst4, imask);

            // *FROM
            // ---0 0000
            // ---1 1111
            // ---2 2222
            // ---3 3333
            // ---4 4444
            //
            // *TO
            // 1110 0000
            // 3222 2211
            // 4444 3333
            // ---- ---4
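            // i.e. the five 5-float rows dst0..dst4 are packed into three
            // dense 256-bit registers (sum0..sum2) plus one scalar (sum3),
            // so each 25-float kernel can be applied with three full-width
            // FMAs and one scalar FMA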
            __m256 sum0 = _mm256_blend_ps(
                dst0,
                leftShift<20>(dst1),
                0xE0 /* 0b11100000 */
            );
            __m256 sum1 = _mm256_blend_ps(
                leftShift<28>(dst3),
                _mm256_blend_ps(leftShift<8>(dst2), rightShift<12>(dst1), 0x03 /* 0b00000011 */),
                0x7F /* 0b01111111 */
            );
            __m256 sum2 = _mm256_blend_ps(
                leftShift<16>(dst4),
                rightShift<4>(dst3),
                0x0F /* 0b00001111 */
            );
            __m128 sum3 = _mm256_extractf128_ps(dst4, 1);

            size_t widx = 25 * inc;
            size_t wstep = 25 * in.depth_;

            if (tbl.is_empty()) {
                for (serial_size_t outc = 0; outc < out.depth_; outc++, widx += wstep) {
                    __m256 delta_src = _mm256_broadcast_ss(&curr_delta[outc]);
                    const float* pw = (const float*)&W[widx];
                    __m256 w0 = _mm256_loadu_ps(pw + 0);
                    __m256 w1 = _mm256_loadu_ps(pw + 8);
                    __m256 w2 = _mm256_loadu_ps(pw + 16);
                    __m128 w3 = _mm_load_ss(pw + 24);
                    sum0 = madd256_ps(w0, delta_src, sum0);
                    sum1 = madd256_ps(w1, delta_src, sum1);
                    sum2 = madd256_ps(w2, delta_src, sum2);
                    sum3 = madd128_ss(w3, _mm256_castps256_ps128(delta_src), sum3);
                }
            } else {
                for (serial_size_t outc = 0; outc < out.depth_; outc++, widx += wstep) {
                    if (!tbl.is_connected(outc, inc)) {
                        continue;
                    }
                    __m256 delta_src = _mm256_broadcast_ss(&curr_delta[outc]);
                    const float* pw = (const float*)&W[widx];
                    __m256 w0 = _mm256_loadu_ps(pw + 0);
                    __m256 w1 = _mm256_loadu_ps(pw + 8);
                    __m256 w2 = _mm256_loadu_ps(pw + 16);
                    __m128 w3 = _mm_load_ss(pw + 24);
                    sum0 = madd256_ps(w0, delta_src, sum0);
                    sum1 = madd256_ps(w1, delta_src, sum1);
                    sum2 = madd256_ps(w2, delta_src, sum2);
                    sum3 = madd128_ss(w3, _mm256_castps256_ps128(delta_src), sum3);
                }
            }

            // *FROM
            // 1110 0000
            // 3222 2211
            // 4444 3333
            // ---- ---4
            //
            // *TO
            // ---0 0000
            // ---1 1111
            // ---2 2222
            // ---3 3333
            // ---4 4444
            dst0 = _mm256_blend_ps(
                dst0,
                sum0,
                0x1F /* 0b00011111 */
            );
            dst1 = _mm256_blend_ps(
                dst1,
                _mm256_or_ps(
                    rightShift<20>(sum0),
                    leftShift<12>(sum1)
                ),
                0x1F /* 0b00011111 */
            );
            dst2 = _mm256_blend_ps(
                dst2,
                rightShift<8>(sum1),
                0x1F /* 0b00011111 */
            );
            dst3 = _mm256_blend_ps(
                dst3,
                _mm256_or_ps(
                    rightShift<28>(sum1),
                    leftShift<4>(sum2)
                ),
                0x1F /* 0b00011111 */
            );
            dst4 = _mm256_blend_ps(
                dst4,
                _mm256_set_m128(
                    sum3,
                    _mm256_extractf128_ps(sum2, 1)
                ),
                0x1F /* 0b00011111 */
            );

            _mm256_storeu_ps(delta_dst0, dst0);
            _mm256_storeu_ps(delta_dst1, dst1);
            _mm256_storeu_ps(delta_dst2, dst2);
            _mm256_storeu_ps(delta_dst3, dst3);
            _mm256_maskstore_ps(delta_dst4, imask, dst4);
        } // for
    } else {
        for (serial_size_t inc = 0; inc < in.depth_; ++inc, pdelta_dst_org += in_padded_area) {
            for (serial_size_t outc = 0; outc < out.depth_; outc++) {
                if (!tbl.is_connected(outc, inc)) continue;

                const float* pw = &W[25 * (in.depth_ * outc + inc)];
                const float* pdelta_src = &curr_delta[out.get_index(0, 0, outc)];
                float* pdelta_dst = pdelta_dst_org;
                __m256 w0a = _mm256_maskload_ps(pw + 0, imask);
                __m256 w1a = _mm256_maskload_ps(pw + 5, imask);
                __m256 w2a = _mm256_maskload_ps(pw + 10, imask);
                __m256 w3a = _mm256_maskload_ps(pw + 15, imask);
                __m256 w4a = _mm256_maskload_ps(pw + 20, imask);
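                // each masked load picks up the five valid floats of one
                // kernel row and zeroes the upper three lanes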
                for (serial_size_t y = 0; y < out.height_; y++) {
                    float* delta_dst0 = pdelta_dst;
                    float* delta_dst1 = &pdelta_dst[in_padded.width_ * 1];
                    float* delta_dst2 = &pdelta_dst[in_padded.width_ * 2];
                    float* delta_dst3 = &pdelta_dst[in_padded.width_ * 3];
                    float* delta_dst4 = &pdelta_dst[in_padded.width_ * 4];
                    for (serial_size_t x = 0; x < out.width_; x++) {
                        __m256 delta_src = _mm256_broadcast_ss(pdelta_src + x);
                        __m256 dst0 = _mm256_loadu_ps(delta_dst0);
                        __m256 dst1 = _mm256_loadu_ps(delta_dst1);
                        __m256 dst2 = _mm256_loadu_ps(delta_dst2);
                        __m256 dst3 = _mm256_loadu_ps(delta_dst3);
                        __m256 dst4 = _mm256_maskload_ps(delta_dst4, imask);
                        dst0 = madd256_ps(w0a, delta_src, dst0);
                        dst1 = madd256_ps(w1a, delta_src, dst1);
                        dst2 = madd256_ps(w2a, delta_src, dst2);
                        dst3 = madd256_ps(w3a, delta_src, dst3);
                        dst4 = madd256_ps(w4a, delta_src, dst4);
                        _mm256_storeu_ps(delta_dst0, dst0);
                        _mm256_storeu_ps(delta_dst1, dst1);
                        _mm256_storeu_ps(delta_dst2, dst2);
                        _mm256_storeu_ps(delta_dst3, dst3);
                        _mm256_maskstore_ps(delta_dst4, imask, dst4);
                        delta_dst0 += w_stride;
                        delta_dst1 += w_stride;
                        delta_dst2 += w_stride;
                        delta_dst3 += w_stride;
                        delta_dst4 += w_stride;
                    } // for x
                    pdelta_src += out.width_;
                    pdelta_dst += h_stride2;
                } // for y
            } // for outc
        } // for inc
    }

    // accumulate dW
    if (out.width_ == 1 && out.height_ == 1) {
        const float* pprev_out = &prev_out[0];
        for (serial_size_t inc = 0; inc < in.depth_; ++inc, pprev_out += in_padded_area) {
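            // pack the 5x5 input patch (one row per padded-input line) into a
            // contiguous, 32-byte aligned buffer so it can be reloaded below
            // as 8 + 8 + 8 + 1 floats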
            VECTORIZE_ALIGN(32) float floats[28];
            size_t in_padded_width = in_padded.width_;
            _mm256_store_ps(&floats[0], _mm256_loadu_ps(pprev_out + in_padded_width * 0));
            _mm256_storeu_ps(&floats[5], _mm256_loadu_ps(pprev_out + in_padded_width * 1));
            _mm256_storeu_ps(&floats[10], _mm256_loadu_ps(pprev_out + in_padded_width * 2));
            _mm256_storeu_ps(&floats[15], _mm256_loadu_ps(pprev_out + in_padded_width * 3));
            _mm256_storeu_ps(&floats[20], _mm256_maskload_ps(pprev_out + in_padded_width * 4, imask));
            __m256 prevos0 = _mm256_load_ps(&floats[0]);
            __m256 prevos1 = _mm256_load_ps(&floats[8]);
            __m256 prevos2 = _mm256_load_ps(&floats[16]);
            __m128 prevos3 = _mm_load_ss(&floats[24]);
            serial_size_t widx = 25 * inc;
            serial_size_t widx_delta = 25 * in.depth_;
            float* pdW = &dW[widx];
            for (serial_size_t outc = 0; outc < out.depth_; outc++, pdW += widx_delta) {
                if (!tbl.is_connected(outc, inc)) {
                    continue;
                }
                __m256 delta = _mm256_broadcast_ss(&curr_delta[outc]);
                __m256 w0 = _mm256_loadu_ps(pdW + 0);
                __m256 w1 = _mm256_loadu_ps(pdW + 8);
                __m256 w2 = _mm256_loadu_ps(pdW + 16);
                __m128 w3 = _mm_load_ss(pdW + 24);
                w0 = madd256_ps(prevos0, delta, w0);
                w1 = madd256_ps(prevos1, delta, w1);
                w2 = madd256_ps(prevos2, delta, w2);
                w3 = madd128_ss(prevos3, _mm256_castps256_ps128(delta), w3);
                _mm256_storeu_ps(pdW + 0, w0);
                _mm256_storeu_ps(pdW + 8, w1);
                _mm256_storeu_ps(pdW + 16, w2);
                _mm_store_ss(pdW + 24, w3);
            }
        }
    } else {
        // prepare load-mask beforehand
        const size_t nblocks = out.width_ >> 3;
        static const int32_t masks[] = {
            -1, -1, -1, -1,
            -1, -1, -1, -1,
             0,  0,  0,  0,
             0,  0,  0,  0,
        };
        const size_t remainder = out.width_ & 7;
        __m256i mask = _mm256_loadu_si256((const __m256i*)(masks + 8 - remainder));
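        // masks holds eight -1s followed by eight 0s; loading 8 ints starting
        // at (8 - remainder) yields a mask whose first `remainder` lanes are
        // set, used to guard the tail of each output row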
        auto& weight = params.weight;
        for (serial_size_t inc = 0; inc < in.depth_; ++inc) {
            for (serial_size_t outc = 0; outc < out.depth_; outc++) {

                if (!tbl.is_connected(outc, inc)) continue;
                const float* delta = &curr_delta[out.get_index(0, 0, outc)];

                serial_size_t widx = weight.get_index(0, 0, in.depth_ * outc + inc);
                for (serial_size_t wy = 0; wy < 5 /* weight.height_ */; wy++) {
                    for (serial_size_t wx = 0; wx < 5 /* weight.width_ */; wx++) {
                        const float* prevo = &prev_out[in_padded.get_index(wx, wy, inc)];

                        if (w_stride > 1) {
                            float_t dst = float_t(0);

                            for (serial_size_t y = 0; y < params.out.height_; y++) {
                                serial_size_t prevo_idx = y * params.in_padded.width_ * params.h_stride;
                                serial_size_t delta_idx = y * params.out.width_;

                                for (serial_size_t x = 0; x < params.out.width_; x++) {
                                    dst += prevo[prevo_idx + x * params.w_stride] * delta[delta_idx + x];
                                }
                            }
                            dW[widx] += dst;
                        } else {
                            __m128 prev_sum = _mm_load_ss(&dW[widx]);
                            __m256 sum0 = _mm256_setzero_ps();
                            __m256 sum1 = _mm256_setzero_ps();
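                            // sum0 accumulates the full 8-float blocks,
                            // sum1 the masked tail of each output row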
                            for (serial_size_t y = 0; y < out.height_; y++) {
                                // vectorize::dot
                                const float* pa = prevo + y * in_padded.width_ * params.h_stride;
                                const float* pb = delta + y * out.width_;
                                for (size_t i = 0; i < nblocks; ++i) {
                                    __m256 a = _mm256_loadu_ps(pa + 8 * i);
                                    __m256 b = _mm256_loadu_ps(pb + 8 * i);
                                    sum0 = madd256_ps(a, b, sum0);
                                }
                                if (remainder) {
                                    __m256 a = _mm256_maskload_ps(pa + 8 * nblocks, mask);
                                    __m256 b = _mm256_maskload_ps(pb + 8 * nblocks, mask);
                                    sum1 = madd256_ps(a, b, sum1);
                                }
                            }
                            sum1 = _mm256_and_ps(sum1, _mm256_castsi256_ps(mask));
                            __m256 sum = _mm256_add_ps(sum0, sum1);
                            _mm_store_ss(&dW[widx], _mm_add_ps(prev_sum, hsum256_ps(sum)));
                        }
                        ++widx;
                    }
                }
            }
        }
    }

    // accumulate db
    if (params.has_bias) {
        //fvec_t& db = *in_grad[2];

        if (out.width_ == 1 && out.height_ == 1) {
            for (serial_size_t outc = 0; outc < out.depth_; outc++) {
                db[outc] += curr_delta[outc];
            }
        } else {
            for (serial_size_t outc = 0; outc < out.depth_; outc++) {
                const float* delta = &curr_delta[out.get_index(0, 0, outc)];
                db[outc] += std::accumulate(delta, delta + out.width_ * out.height_, float(0));
            }
        }
    }
} // avx_conv2d_5x5_back_kernel_one (float version)

// double version
template <typename Allocator>
void avx_conv2d_5x5_back_kernel(const core::conv_params& params,
                                const std::vector<std::vector<double, Allocator>>& prev_out,
                                const std::vector<double, Allocator>& W,
                                std::vector<std::vector<double, Allocator>>& dW,
                                std::vector<std::vector<double, Allocator>>& db,
                                std::vector<std::vector<double, Allocator>>& curr_delta,
                                std::vector<std::vector<double, Allocator>>& prev_delta) {
    // the backward pass falls back to the internal (non-AVX) kernel when
    // float_t == double
    conv2d_op_internal(prev_out, W, dW, db, curr_delta, prev_delta, params, true);
}

// float version: dispatch each sample in the batch to the single-sample kernel
template <typename Allocator>
void avx_conv2d_5x5_back_kernel(const core::conv_params& params,
                                const std::vector<std::vector<float, Allocator>>& prev_out,
                                const std::vector<float, Allocator>& W,
                                std::vector<std::vector<float, Allocator>>& dW,
                                std::vector<std::vector<float, Allocator>>& db,
                                std::vector<std::vector<float, Allocator>>& curr_delta,
                                std::vector<std::vector<float, Allocator>>& prev_delta) {
    for_i(prev_out.size(), [&](int sample) {
        avx_conv2d_5x5_back_kernel_one(params, prev_out[sample], W, dW[sample], db[sample],
                                       curr_delta[sample], &prev_delta[sample]);
    });
}


#endif // CNN_USE_AVX

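// Backward kernel entry point: 5x5 convolutions take the AVX path when
// CNN_USE_AVX is defined; everything else falls back to conv2d_op_internal.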
inline void
conv2d_grad_op_avx(const tensor_t& prev_out,
                   const vec_t& W,
                   tensor_t& dW,
                   tensor_t& db,
                   tensor_t& curr_delta,
                   tensor_t& prev_delta,
                   const core::conv_params& params,
                   const bool layer_parallelize) {
#ifdef CNN_USE_AVX
    if (params.weight.height_ == 5 && params.weight.width_ == 5) {
        avx_conv2d_5x5_back_kernel(params, prev_out, W, dW, db, curr_delta, prev_delta);
        return;
    }
#endif

    conv2d_op_internal(prev_out, W, dW, db, curr_delta,
                       prev_delta, params, layer_parallelize);
}

} // namespace kernels
} // namespace tiny_dnn