tiny_dnn 1.0.0
A header-only, dependency-free deep learning framework in C++11
avx_kernel_common.h
/*
    Copyright (c) 2016, Taiga Nomi, Edgar Riba
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:
    * Redistributions of source code must retain the above copyright
    notice, this list of conditions and the following disclaimer.
    * Redistributions in binary form must reproduce the above copyright
    notice, this list of conditions and the following disclaimer in the
    documentation and/or other materials provided with the distribution.
    * Neither the name of the <organization> nor the
    names of its contributors may be used to endorse or promote products
    derived from this software without specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
    EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
    WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
    DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY
    DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
    LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
    ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#pragma once

#include <immintrin.h>   // AVX intrinsics (__m256, __m128, ...)
#include <type_traits>   // std::false_type, used by the shift helpers below

#ifndef CNN_USE_AVX
#error Advanced Vector Extensions required.
#endif

// Compose an __m256 from two __m128 halves (va -> upper lane, vb -> lower lane);
// defined here for compilers that do not provide _mm256_set_m128.
#ifndef _mm256_set_m128
#define _mm256_set_m128(va, vb) \
  _mm256_insertf128_ps(_mm256_castps128_ps256(vb), va, 1)
#endif

// Multiply-add helpers: each returns a * b + c using separate multiply and
// add instructions, so no FMA instruction-set support is required.
inline __m256 madd256_ps(__m256 a, __m256 b, __m256 c) {
  return _mm256_add_ps(_mm256_mul_ps(a, b), c);
}
inline __m128 madd128_ps(__m128 a, __m128 b, __m128 c) {
  return _mm_add_ps(_mm_mul_ps(a, b), c);
}
inline __m128 madd128_ss(__m128 a, __m128 b, __m128 c) {
  return _mm_add_ss(_mm_mul_ss(a, b), c);
}
inline __m256d madd256_pd(__m256d a, __m256d b, __m256d c) {
  return _mm256_add_pd(_mm256_mul_pd(a, b), c);
}
inline __m128d madd128_pd(__m128d a, __m128d b, __m128d c) {
  return _mm_add_pd(_mm_mul_pd(a, b), c);
}
inline __m128d madd128_sd(__m128d a, __m128d b, __m128d c) {
  return _mm_add_sd(_mm_mul_sd(a, b), c);
}

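// --- Illustrative usage sketch (not part of the original header) ------------
// A minimal example of how the madd helpers are typically used: accumulating
// element-wise products of two float arrays into an AVX register. The function
// name is hypothetical and `n` is assumed to be a multiple of 8.
inline __m256 example_mul_accumulate(const float *a, const float *b, int n) {
  __m256 acc = _mm256_setzero_ps();
  for (int i = 0; i < n; i += 8) {
    __m256 va = _mm256_loadu_ps(a + i);  // unaligned load of 8 floats
    __m256 vb = _mm256_loadu_ps(b + i);
    acc = madd256_ps(va, vb, acc);       // acc += va * vb
  }
  return acc;
}
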
// Horizontally add elements of __m256 type argument (sadly, _mm256_hadd_ps isn't good enough)
// http://stackoverflow.com/a/13222410/4699324
// x = ( x7, x6, x5, x4, x3, x2, x1, x0 )
inline __m128 hsum256_ps(__m256 x) {
  // hiQuad = ( x7, x6, x5, x4 )
  const __m128 hiQuad = _mm256_extractf128_ps(x, 1);
  // loQuad = ( x3, x2, x1, x0 )
  const __m128 loQuad = _mm256_castps256_ps128(x);
  // sumQuad = ( x3+x7, x2+x6, x1+x5, x0+x4 )
  const __m128 sumQuad = _mm_add_ps(loQuad, hiQuad);
  // loDual = ( -, -, x1+x5, x0+x4 )
  const __m128 loDual = sumQuad;
  // hiDual = ( -, -, x3+x7, x2+x6 )
  const __m128 hiDual = _mm_movehl_ps(sumQuad, sumQuad);
  // sumDual = ( -, -, x1+x3 + x5+x7, x0+x2 + x4+x6 )
  const __m128 sumDual = _mm_add_ps(loDual, hiDual);
  // lo = ( -, -, -, x0+x2 + x4+x6 )
  const __m128 lo = sumDual;
  // hi = ( -, -, -, x1+x3 + x5+x7 )
  const __m128 hi = _mm_shuffle_ps(sumDual, sumDual, 0x1);
  // sum = ( -, -, -, x0+x1+x2+x3 + x4+x5+x6+x7 )
  const __m128 sum = _mm_add_ss(lo, hi);
  return sum;
}

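// --- Illustrative usage sketch (not part of the original header) ------------
// Example: reduce an AVX accumulator to a scalar. hsum256_ps leaves the total
// in the lowest element, which _mm_cvtss_f32 extracts. The function name is
// hypothetical; `n` is assumed to be a multiple of 8.
inline float example_dot_product(const float *a, const float *b, int n) {
  __m256 acc = _mm256_setzero_ps();
  for (int i = 0; i < n; i += 8) {
    acc = madd256_ps(_mm256_loadu_ps(a + i), _mm256_loadu_ps(b + i), acc);
  }
  return _mm_cvtss_f32(hsum256_ps(acc));  // horizontal sum -> scalar
}
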
// Horizontally add elements of each __m256 type argument at once
inline __m128 hsum2x256_ps(__m256 a, __m256 b) {
  // (b3, b2, b1, b0, a3, a2, a1, a0)
  __m256 x = _mm256_permute2f128_ps(a, b, 0x20);
  // (b7, b6, b5, b4, a7, a6, a5, a4)
  __m256 y = _mm256_permute2f128_ps(a, b, 0x31);
  // (b3+b7, b2+b6, b1+b5, b0+b4, a3+a7, a2+a6, a1+a5, a0+a4)
  x = _mm256_add_ps(x, y);
  // (-, -, b3+b7, b2+b6, -, -, a3+a7, a2+a6)
  y = _mm256_permute_ps(x, _MM_SHUFFLE(3, 2, 3, 2));
  // (-, -, b1+b5+b3+b7, b0+b4+b2+b6, -, -, a1+a5+a3+a7, a0+a4+a2+a6)
  x = _mm256_add_ps(x, y);
  // (-, -, -, b1+b5+b3+b7, -, -, -, a1+a5+a3+a7)
  y = _mm256_permute_ps(x, _MM_SHUFFLE(1, 1, 1, 1));
  // (-, -, -, b1+b5+b3+b7+b0+b4+b2+b6, -, -, -, a1+a5+a3+a7+a0+a4+a2+a6)
  x = _mm256_add_ps(x, y);
  // upper = (-, -, -, b1+b5+b3+b7+b0+b4+b2+b6)
  __m128 upper = _mm256_extractf128_ps(x, 1);
  // ret = (-, -, b1+b5+b3+b7+b0+b4+b2+b6, a1+a5+a3+a7+a0+a4+a2+a6)
  __m128 ret = _mm_unpacklo_ps(_mm256_castps256_ps128(x), upper);
  return ret;
}

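// --- Illustrative usage sketch (not part of the original header) ------------
// Example: hsum2x256_ps reduces two accumulators in one call; the total of `a`
// lands in element 0 of the result and the total of `b` in element 1. The
// function name is hypothetical.
inline void example_hsum_pair(__m256 a, __m256 b, float *sum_a, float *sum_b) {
  __m128 s = hsum2x256_ps(a, b);
  *sum_a = _mm_cvtss_f32(s);                        // element 0: sum of a
  *sum_b = _mm_cvtss_f32(_mm_shuffle_ps(s, s, 1));  // element 1: sum of b
}
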
inline __m128d hsum256_pd(__m256d x) {
  // hiDual = ( x3, x2 )
  const __m128d hiDual = _mm256_extractf128_pd(x, 1);
  // loDual = ( x1, x0 )
  const __m128d loDual = _mm256_castpd256_pd128(x);
  // sumDual = ( x2+x3, x0+x1 )
  const __m128d sumDual = _mm_add_pd(loDual, hiDual);
  // sum = ( 0, x0+x1+x2+x3 )
  const __m128d sum = _mm_hadd_pd(sumDual, _mm_setzero_pd());
  return sum;
}

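// --- Illustrative usage sketch (not part of the original header) ------------
// Example: the double-precision total sits in the low element of the __m128d
// that hsum256_pd returns; _mm_cvtsd_f64 extracts it as a scalar. The function
// name is hypothetical.
inline double example_hsum_double(__m256d v) {
  return _mm_cvtsd_f64(hsum256_pd(v));
}
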
// Helper type used to trigger a static_assert for unsupported shift amounts.
template <int n>
struct foobar : std::false_type {};

// Byte-shift a YMM register across its 128-bit lanes.
// Limitation: the shift amount is an immediate and must be a multiple of 4.

template <int n>
inline __m256 leftShift(__m256 a) {
  static_assert(foobar<n>::value, "unsupported shift amount");
  return a;
}

// http://stackoverflow.com/q/19516585
template <>
inline __m256 leftShift<4>(__m256 x) {
  // x  = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x6, x5, x4, x7, x2, x1, x0, x3)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(2, 1, 0, 3));
  // t1 = (x2, x1, x0, x3, 0, 0, 0, 0)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x08);
  // y  = (x6, x5, x4, x3, x2, x1, x0, 0)
  __m256 y = _mm256_blend_ps(t0, t1, 0x11);
  return y;
}

// http://stackoverflow.com/q/19516585
template <>
inline __m256 leftShift<8>(__m256 x) {
  // x  = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x5, x4, x7, x6, x1, x0, x3, x2)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(1, 0, 3, 2));
  // t1 = (x1, x0, x3, x2, 0, 0, 0, 0)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x08);
  // y  = (x5, x4, x3, x2, x1, x0, 0, 0)
  __m256 y = _mm256_blend_ps(t0, t1, 0x33 /* 0b00110011 */);
  return y;
}

template <>
inline __m256 leftShift<12>(__m256 x) {
  // x  = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x4, x7, x6, x5, x0, x3, x2, x1)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(0, 3, 2, 1));
  // t1 = (x0, x3, x2, x1, 0, 0, 0, 0)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x08);
  // y  = (x4, x3, x2, x1, x0, 0, 0, 0)
  __m256 y = _mm256_blend_ps(t0, t1, 0x77 /* 0b01110111 */);
  return y;
}

template <>
inline __m256 leftShift<16>(__m256 x) {
  // x  = (x7, x6, x5, x4, x3, x2, x1, x0)

  // y  = (x3, x2, x1, x0, 0, 0, 0, 0)
  __m256 y = _mm256_permute2f128_ps(x, x, 0x08);
  return y;
}

template <>
inline __m256 leftShift<20>(__m256 x) {
  // x  = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x6, x5, x4, x7, x2, x1, x0, x3)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(2, 1, 0, 3));
  // t1 = (x2, x1, x0, x3, 0, 0, 0, 0)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x08);
  // y  = (x2, x1, x0, 0, 0, 0, 0, 0)
  __m256 y = _mm256_blend_ps(t1, _mm256_setzero_ps(), 0x10 /* 0b00010000 */);
  return y;
}

template <>
inline __m256 leftShift<24>(__m256 x) {
  // x  = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x5, x4, x7, x6, x1, x0, x3, x2)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(1, 0, 3, 2));
  // t1 = (x1, x0, x3, x2, 0, 0, 0, 0)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x08);
  // y  = (x1, x0, 0, 0, 0, 0, 0, 0)
  __m256 y = _mm256_blend_ps(_mm256_setzero_ps(), t1, 0xC0 /* 0b11000000 */);
  return y;
}

template <>
inline __m256 leftShift<28>(__m256 x) {
  // x  = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x4, x7, x6, x5, x0, x3, x2, x1)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(0, 3, 2, 1));
  // t1 = (x0, x3, x2, x1, 0, 0, 0, 0)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x08);
  // y  = (x0, 0, 0, 0, 0, 0, 0, 0)
  __m256 y = _mm256_blend_ps(_mm256_setzero_ps(), t1, 0x80 /* 0b10000000 */);
  return y;
}

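// --- Illustrative usage sketch (not part of the original header) ------------
// Example: with 4-byte floats, leftShift<4> moves every element up one slot
// and shifts a zero into the low position. The function name is hypothetical.
inline void example_left_shift(float out[8]) {
  __m256 x = _mm256_setr_ps(0, 1, 2, 3, 4, 5, 6, 7);  // element i holds i
  _mm256_storeu_ps(out, leftShift<4>(x));  // out = {0, 0, 1, 2, 3, 4, 5, 6}
}
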
template <int n>
inline __m256 rightShift(__m256 a) {
  static_assert(foobar<n>::value, "unsupported shift amount");
  return a;
}

// http://stackoverflow.com/a/19532415/4699324
template <>
inline __m256 rightShift<4>(__m256 x) {
  // x  = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x4, x7, x6, x5, x0, x3, x2, x1)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(0, 3, 2, 1));
  // t1 = ( 0,  0,  0,  0, x4, x7, x6, x5)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x81);

  //      ( -, x7, x6, x5,  -, x3, x2, x1)
  //      ( 0,  -,  -,  -, x4,  -,  -,  -)
  // y  = ( 0, x7, x6, x5, x4, x3, x2, x1)
  __m256 y = _mm256_blend_ps(t0, t1, 0x88 /* 0b10001000 */);
  return y;
}

template <>
inline __m256 rightShift<8>(__m256 x) {
  // x  = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x5, x4, x7, x6, x1, x0, x3, x2)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(1, 0, 3, 2));
  // t1 = ( 0,  0,  0,  0, x5, x4, x7, x6)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x81);

  //      ( -,  -, x7, x6,  -,  -, x3, x2)
  //      ( 0,  0,  -,  -, x5, x4,  -,  -)
  // y  = ( 0,  0, x7, x6, x5, x4, x3, x2)
  __m256 y = _mm256_blend_ps(t0, t1, 0xCC /* 0b11001100 */);
  return y;
}

template <>
inline __m256 rightShift<12>(__m256 x) {
  // x  = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x6, x5, x4, x7, x2, x1, x0, x3)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(2, 1, 0, 3));
  // t1 = ( 0,  0,  0,  0, x6, x5, x4, x7)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x81);

  //      ( -,  -,  -, x7,  -,  -,  -, x3)
  //      ( 0,  0,  0,  -, x6, x5, x4,  -)
  // y  = ( 0,  0,  0, x7, x6, x5, x4, x3)
  __m256 y = _mm256_blend_ps(t0, t1, 0xEE /* 0b11101110 */);
  return y;
}

template <>
inline __m256 rightShift<16>(__m256 x) {
  // x  = (x7, x6, x5, x4, x3, x2, x1, x0)

  // y  = ( 0,  0,  0,  0, x7, x6, x5, x4)
  __m256 y = _mm256_permute2f128_ps(x, x, 0x81);
  return y;
}

template <>
inline __m256 rightShift<20>(__m256 x) {
  // x  = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x4, x7, x6, x5, x0, x3, x2, x1)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(0, 3, 2, 1));
  // t1 = ( 0,  0,  0,  0, x4, x7, x6, x5)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x81);

  //      ( -,  -,  -,  -,  -, x7, x6, x5)
  //      ( 0,  0,  0,  0,  0,  -,  -,  -)
  // y  = ( 0,  0,  0,  0,  0, x7, x6, x5)
  __m256 y = _mm256_blend_ps(t1, _mm256_setzero_ps(), 0xF8 /* 0b11111000 */);
  return y;
}

template <>
inline __m256 rightShift<24>(__m256 x) {
  // x  = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x5, x4, x7, x6, x1, x0, x3, x2)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(1, 0, 3, 2));
  // t1 = ( 0,  0,  0,  0, x5, x4, x7, x6)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x81);

  //      ( -,  -,  -,  -,  -,  -, x7, x6)
  //      ( 0,  0,  0,  0,  0,  0,  -,  -)
  // y  = ( 0,  0,  0,  0,  0,  0, x7, x6)
  __m256 y = _mm256_blend_ps(t1, _mm256_setzero_ps(), 0xFC /* 0b11111100 */);
  return y;
}

template <>
inline __m256 rightShift<28>(__m256 x) {
  // x  = (x7, x6, x5, x4, x3, x2, x1, x0)

  // t0 = (x6, x5, x4, x7, x2, x1, x0, x3)
  __m256 t0 = _mm256_permute_ps(x, _MM_SHUFFLE(2, 1, 0, 3));
  // t1 = ( 0,  0,  0,  0, x6, x5, x4, x7)
  __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x81);

  //      ( -,  -,  -,  -,  -,  -,  -, x7)
  //      ( 0,  0,  0,  0,  0,  0,  0,  -)
  // y  = ( 0,  0,  0,  0,  0,  0,  0, x7)
  __m256 y = _mm256_blend_ps(t1, _mm256_setzero_ps(), 0xFE /* 0b11111110 */);
  return y;
}

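// --- Illustrative usage sketch (not part of the original header) ------------
// Example: rightShift<4> moves every element down one slot and shifts a zero
// into the top position. The function name is hypothetical.
inline void example_right_shift(float out[8]) {
  __m256 x = _mm256_setr_ps(0, 1, 2, 3, 4, 5, 6, 7);  // element i holds i
  _mm256_storeu_ps(out, rightShift<4>(x));  // out = {1, 2, 3, 4, 5, 6, 7, 0}
}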