/*
 * Copyright (c) 2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "helpers.h"

// Select the comparison used by the y/z/w reduction kernels below:
// ARG_MAX keeps the larger value, ARG_MIN the smaller. Exactly one of
// the two flags must be defined at compile time.
#if defined(ARG_MAX)
#define CONDITION_TO_USE(x, y) ISGREATER(x, y)
#elif defined(ARG_MIN)
#define CONDITION_TO_USE(x, y) ISLESS(x, y)
#else // !(defined(ARG_MAX) || defined(ARG_MIN))
#error "Unsupported reduction operation!"
#endif // defined(ARG_MAX)

#if defined(DATA_TYPE_OUTPUT)
#if defined(WIDTH)
#if defined(ARG_MIN)
#if defined(PREV_OUTPUT)
/** Resolve the index of the minimum value among the candidate indices
 * produced by a previous reduction stage.
 *
 * @param[in] input    Pointer to the first value of the input row.
 * @param[in] prev_res Pointer to the candidate indices computed by the previous stage.
 * @param[in] x_idx    Index of the chunk processed by this work-item.
 *
 * @return index of the vector.
 */
inline DATA_TYPE_OUTPUT arg_idx_min_prev_out(__global const DATA_TYPE *input, __global const DATA_TYPE_OUTPUT *prev_res, const int x_idx)
{
    // Number of candidate indices to scan; clamped for the tail chunk so we
    // do not read candidates past WIDTH.
    int num_elems = (x_idx + 1) * 16;
    if(num_elems > WIDTH)
    {
        num_elems = WIDTH - x_idx * 16;
    }
    // Linear scan keeping the candidate whose input value is smallest.
    DATA_TYPE_OUTPUT best = prev_res[0];
    for(int i = 1; i < num_elems; ++i)
    {
        const DATA_TYPE_OUTPUT candidate = prev_res[i];
        if(*(input + candidate) < *(input + best))
        {
            best = candidate;
        }
    }
    return best;
}
#else // !defined(PREV_OUTPUT)
/** Find index minimum value of a vector
 *
 * @param[in] input Pointer to the first value.
 * @param[in] x_idx Index of the 16-element chunk processed by this work-item (only used when WIDTH >= 16).
 *
 * @return index of the vector.
 */
inline DATA_TYPE_OUTPUT arg_idx_min(__global const DATA_TYPE *input, const int x_idx)
{
#if WIDTH < 16
    // Scalar path: fewer than 16 elements, walk the row linearly. '<' is
    // strict, so the first occurrence of the minimum wins on ties.
    DATA_TYPE_OUTPUT res = 0;
    for(DATA_TYPE_OUTPUT x_v = res + 1; x_v < WIDTH; ++x_v)
    {
        res = select(res, x_v, *(input + x_v) < * (input + res));
    }
    return res;
#else // WIDTH >= 16
    // Vector path: reduce the 16-element chunk starting at x_idx * 16.
    int x_elem = x_idx * 16;
    // If the chunk would read past WIDTH, shift the 16-wide load back so it
    // stays in bounds; the shifted-in elements belong to the previous chunk
    // and re-evaluating them is harmless.
    const int x_goback = select(0, 16 - WIDTH % 16, x_elem + 16 > WIDTH);
    x_elem -= x_goback;

    VEC_DATA_TYPE(DATA_TYPE, 16)
    in = vload16(0, input - x_goback);
    // Per-lane candidate indices, relative to the start of the (shifted) chunk.
    VEC_DATA_TYPE(DATA_TYPE_OUTPUT, 16)
    res = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 };

    // Pairwise tree reduction 16 -> 8 -> 4 -> 2 -> 1: each step keeps the
    // half with the smaller value and, on equal values, the smaller index.
    // The first step's '<=' already favours the lower-index half.
    VEC_DATA_TYPE(COND_DATA_TYPE, 8)
    idx_sel = (in.s01234567 <= in.s89abcdef);
    in.s01234567 = select(in.s89abcdef, in.s01234567, idx_sel);
    res.s01234567 = select(res.s89abcdef, res.s01234567, CONVERT(idx_sel, int8));

    idx_sel.s0123 = (in.s0123 < in.s4567) || (in.s0123 == in.s4567 && CONVERT((res.s0123 < res.s4567), VEC_DATA_TYPE(COND_DATA_TYPE, 4)));
    in.s0123 = select(in.s4567, in.s0123, idx_sel.s0123);
    res.s0123 = select(res.s4567, res.s0123, CONVERT(idx_sel.s0123, int4));

    idx_sel.s01 = (in.s01 < in.s23) || (in.s01 == in.s23 && CONVERT((res.s01 < res.s23), VEC_DATA_TYPE(COND_DATA_TYPE, 2)));
    in.s01 = select(in.s23, in.s01, idx_sel.s01);
    res.s01 = select(res.s23, res.s01, CONVERT(idx_sel.s01, int2));

    idx_sel.s0 = (in.s0 < in.s1) || (in.s0 == in.s1 && CONVERT((res.s0 < res.s1), COND_DATA_TYPE));
    res.s0 = select(res.s1, res.s0, CONVERT(idx_sel.s0, int));

    // Convert the chunk-relative winner back to an absolute x index.
    return res.s0 + x_elem;
#endif // WIDTH < 16
}
#endif // defined(PREV_OUTPUT)
#endif // defined(ARG_MIN)
#if defined(ARG_MAX)
#if defined(PREV_OUTPUT)
/** Resolve the index of the maximum value among the candidate indices
 * produced by a previous reduction stage.
 *
 * @param[in] input    Pointer to the first value of the input row.
 * @param[in] prev_res Pointer to the candidate indices computed by the previous stage.
 * @param[in] x_idx    Index of the chunk processed by this work-item.
 *
 * @return index of the vector.
 */
inline DATA_TYPE_OUTPUT arg_idx_max_prev_out(__global const DATA_TYPE *input, __global const DATA_TYPE_OUTPUT *prev_res, const int x_idx)
{
    // Number of candidate indices to scan; clamped for the tail chunk so we
    // do not read candidates past WIDTH.
    int num_elems = (x_idx + 1) * 16;
    if(num_elems > WIDTH)
    {
        num_elems = WIDTH - x_idx * 16;
    }
    // Linear scan keeping the candidate whose input value is largest.
    DATA_TYPE_OUTPUT best = prev_res[0];
    for(int i = 1; i < num_elems; ++i)
    {
        const DATA_TYPE_OUTPUT candidate = prev_res[i];
        if(*(input + candidate) > *(input + best))
        {
            best = candidate;
        }
    }
    return best;
}
#else // !defined(PREV_OUTPUT)
/** Find index maximum value of a vector
 *
 * @param[in] input Pointer to the first value.
 * @param[in] x_idx Index of the 16-element chunk processed by this work-item (only used when WIDTH >= 16).
 *
 * @return index of the vector.
 */
inline DATA_TYPE_OUTPUT arg_idx_max(__global const DATA_TYPE *input, const int x_idx)
{
#if WIDTH < 16
    // Scalar path: fewer than 16 elements, walk the row linearly. '>' is
    // strict, so the first occurrence of the maximum wins on ties.
    DATA_TYPE_OUTPUT res = 0;
    for(DATA_TYPE_OUTPUT x_v = res + 1; x_v < WIDTH; ++x_v)
    {
        res = select(res, x_v, *(input + x_v) > *(input + res));
    }
    return res;
#else // WIDTH >= 16
    // Vector path: reduce the 16-element chunk starting at x_idx * 16.
    int x_elem = x_idx * 16;
    // If the chunk would read past WIDTH, shift the 16-wide load back so it
    // stays in bounds; the shifted-in elements belong to the previous chunk
    // and re-evaluating them is harmless.
    const int x_goback = select(0, 16 - WIDTH % 16, x_elem + 16 > WIDTH);
    x_elem -= x_goback;

    VEC_DATA_TYPE(DATA_TYPE, 16)
    in = vload16(0, input - x_goback);
    // Per-lane candidate indices, relative to the start of the (shifted) chunk.
    VEC_DATA_TYPE(DATA_TYPE_OUTPUT, 16)
    res = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 };

    // Pairwise tree reduction 16 -> 8 -> 4 -> 2 -> 1: each step keeps the
    // half with the larger value and, on equal values, the smaller index.
    // The first step's '>=' already favours the lower-index half.
    VEC_DATA_TYPE(COND_DATA_TYPE, 8)
    idx_sel = (in.s01234567 >= in.s89abcdef);
    in.s01234567 = select(in.s89abcdef, in.s01234567, idx_sel);
    res.s01234567 = select(res.s89abcdef, res.s01234567, CONVERT(idx_sel, int8));

    idx_sel.s0123 = (in.s0123 > in.s4567) || (in.s0123 == in.s4567 && CONVERT((res.s0123 < res.s4567), VEC_DATA_TYPE(COND_DATA_TYPE, 4)));
    in.s0123 = select(in.s4567, in.s0123, idx_sel.s0123);
    res.s0123 = select(res.s4567, res.s0123, CONVERT(idx_sel.s0123, int4));

    idx_sel.s01 = (in.s01 > in.s23) || (in.s01 == in.s23 && CONVERT((res.s01 < res.s23), VEC_DATA_TYPE(COND_DATA_TYPE, 2)));
    in.s01 = select(in.s23, in.s01, idx_sel.s01);
    res.s01 = select(res.s23, res.s01, CONVERT(idx_sel.s01, int2));

    idx_sel.s0 = (in.s0 > in.s1) || (in.s0 == in.s1 && CONVERT((res.s0 < res.s1), COND_DATA_TYPE));
    res.s0 = select(res.s1, res.s0, CONVERT(idx_sel.s0, int));

    // Convert the chunk-relative winner back to an absolute x index.
    return res.s0 + x_elem;
#endif // WIDTH < 16
}
#endif // defined(PREV_OUTPUT)
#endif // defined(ARG_MAX)

/** This kernel performs parallel reduction given an operation on x-axis.
 *
 * @note In case the results of previous stages are passed the flag PREV_OUTPUT has to be passed using -DPREV_OUTPUT
 * @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
 * @note The data type of the output must be passed at compile time using -DDATA_TYPE_OUTPUT: e.g. -DDATA_TYPE_OUTPUT=uint
 * @note The arg_max flag must be passed at compile time using -DARG_MAX if we want to compute the ArgMax
 * @note The arg_min flag must be passed at compile time using -DARG_MIN if we want to compute the ArgMin
 *
 * @param[in] src_ptr                                   Pointer to the source tensor. Supported data types: S32/F16/F32
 * @param[in] src_stride_x                              Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x                                src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                              Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y                                src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes         The offset of the first element in the source tensor
 * @param[in] prev_res_ptr                              (Optional) Pointer to previous results tensor. Supported data types: U32/S32
 * @param[in] prev_res_stride_x                         (Optional) Stride of the output tensor in X dimension (in bytes)
 * @param[in] prev_res_step_x                           (Optional) prev_res_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] prev_res_stride_y                         (Optional) Stride of the output tensor in Y dimension (in bytes)
 * @param[in] prev_res_step_y                           (Optional) prev_res_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] prev_res_offset_first_element_in_bytes    (Optional) The offset of the first element in the previous results tensor
 * @param[in] partial_res_ptr                           The local buffer to hold partial result values. Supported data types: U32/S32
 * @param[in] partial_res_stride_x                      Stride of the output tensor in X dimension (in bytes)
 * @param[in] partial_res_step_x                        partial_res_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] partial_res_stride_y                      Stride of the output tensor in Y dimension (in bytes)
 * @param[in] partial_res_step_y                        partial_res_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] partial_res_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] local_results                             Local buffer for storing the partial result
 */
__kernel void arg_min_max_x(
    IMAGE_DECLARATION(src),
#if defined(PREV_OUTPUT)
    IMAGE_DECLARATION(prev_res),
#endif // defined(PREV_OUTPUT)
    IMAGE_DECLARATION(partial_res),
    __local DATA_TYPE_OUTPUT *local_results)
{
#if defined(PREV_OUTPUT)
    Image src = CONVERT_TO_IMAGE_STRUCT_NO_STEP(src);
    Image prev_res = CONVERT_TO_IMAGE_STRUCT(prev_res);
#else // !defined(PREV_OUTPUT)
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
#endif // defined(PREV_OUTPUT)
    Image partial_res = CONVERT_TO_IMAGE_STRUCT(partial_res);

    unsigned int lsize = get_local_size(0);
    unsigned int lid = get_local_id(0);

    const uint x_idx = get_global_id(0);
    const uint y_idx = get_global_id(1);
    // Base of the row this work-item reduces; local_results holds indices
    // relative to this pointer.
    const __global DATA_TYPE *src_in_row = (const __global DATA_TYPE *)(src_ptr + src_offset_first_element_in_bytes + y_idx * src_step_y);

    for(unsigned int y = 0; y < get_local_size(1); ++y)
    {
        // Step 1: each work-item computes the arg index of its 16-wide chunk
        // (or refines candidates from a previous stage when PREV_OUTPUT is set).
#if defined(ARG_MAX)
#if defined(PREV_OUTPUT)
        local_results[lid] = arg_idx_max_prev_out(src_in_row, (__global DATA_TYPE_OUTPUT *)offset(&prev_res, 0, y), x_idx);
#else // !defined(PREV_OUTPUT)
        local_results[lid] = arg_idx_max((__global DATA_TYPE *)offset(&src, 0, y), x_idx);
#endif // defined(PREV_OUTPUT)
#else // defined(ARG_MIN)
#if defined(PREV_OUTPUT)
        local_results[lid] = arg_idx_min_prev_out(src_in_row, (__global DATA_TYPE_OUTPUT *)offset(&prev_res, 0, y), x_idx);
#else // !defined(PREV_OUTPUT)
        local_results[lid] = arg_idx_min((__global DATA_TYPE *)offset(&src, 0, y), x_idx);
#endif // defined(PREV_OUTPUT)
#endif // defined(ARG_MAX) || defined(ARG_MIN)

        barrier(CLK_LOCAL_MEM_FENCE);

        // Step 2: tree reduction in local memory, halving the number of active
        // work-items each iteration. On equal values the smaller index wins.
        for(unsigned int i = lsize >> 1; i > 0; i >>= 1)
        {
            if(lid < i)
            {
                DATA_TYPE tmp0 = *(src_in_row + local_results[lid]);
                DATA_TYPE tmp1 = *(src_in_row + local_results[lid + i]);
#if defined(ARG_MAX)
                local_results[lid] = select(
                                         local_results[lid],
                                         local_results[lid + i],
                                         ((tmp0 == tmp1) && (local_results[lid + i] < local_results[lid])) || (tmp0 < tmp1));
#else // defined(ARG_MIN)
                local_results[lid] = select(
                                         local_results[lid],
                                         local_results[lid + i],
                                         ((tmp0 == tmp1) && (local_results[lid + i] < local_results[lid])) || (tmp0 > tmp1));
#endif // defined(ARG_MAX) || defined(ARG_MIN)
            }
            // All work-items must reach the barrier, hence it sits outside the if.
            barrier(CLK_LOCAL_MEM_FENCE);
        }

        // Step 3: the first work-item of the group writes the group's result.
        if(lid == 0)
        {
            ((__global DATA_TYPE_OUTPUT *)offset(&partial_res, get_group_id(0), y))[0] = local_results[0];
        }
    }
}
#endif // defined(WIDTH)

#if defined(HEIGHT)
/** This kernel performs reduction on y-axis.
 *
 * @note The input data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
 * @note The data type of the output must be passed at compile time using -DDATA_TYPE_OUTPUT: e.g. -DDATA_TYPE_OUTPUT=uint
 * @note The data type of the intermediate results must be passed at compile time using -DDATA_TYPE_PROMOTED: e.g. -DDATA_TYPE_PROMOTED=uint
 * @note The height size must be passed at compile time using -DHEIGHT e.g. -DHEIGHT=128
 *
 * @param[in] src_ptr                              Pointer to the source tensor. Supported data types: S32/F16/F32
 * @param[in] src_stride_x                         Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x                           src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                         Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y                           src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes    The offset of the first element in the source tensor
 * @param[in] output_ptr                           The local buffer to hold the result values. Supported data types: U32/S32
 * @param[in] output_stride_x                      Stride of the output tensor in X dimension (in bytes)
 * @param[in] output_step_x                        output_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] output_stride_y                      Stride of the output tensor in Y dimension (in bytes)
 * @param[in] output_step_y                        output_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] output_offset_first_element_in_bytes The offset of the first element in the output tensor
 */
__kernel void arg_min_max_y(
    IMAGE_DECLARATION(src),
    IMAGE_DECLARATION(output))
{
    Image src    = CONVERT_TO_IMAGE_STRUCT(src);
    Image output = CONVERT_TO_IMAGE_STRUCT(output);

    // Running extremum for each of the 16 columns handled by this work-item,
    // promoted to DATA_TYPE_PROMOTED for the comparisons.
    VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 16)
    ref_values = CONVERT(vload16(0, (__global DATA_TYPE *)offset(&src, 0, 0)), VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 16));

    // Row index at which the current extremum of each column was found.
    VEC_DATA_TYPE(DATA_TYPE_OUTPUT, 16)
    best_idx = 0;
    for(unsigned int row = 1; row < HEIGHT; ++row)
    {
        VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 16)
        row_values = CONVERT(vload16(0, (__global DATA_TYPE *)offset(&src, 0, row)), VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 16));

        // Per-lane mask: non-zero where the new row beats the current extremum.
        VEC_DATA_TYPE(DATA_TYPE_OUTPUT, 16)
        update_mask = CONVERT(CONDITION_TO_USE(row_values, ref_values), VEC_DATA_TYPE(DATA_TYPE_OUTPUT, 16));
        best_idx    = select(best_idx, row, update_mask);
        ref_values  = select(ref_values, row_values, CONDITION_TO_USE(row_values, ref_values));
    }

    // Store the winning row indices
    vstore16(best_idx, 0, (__global DATA_TYPE_OUTPUT *)output.ptr);
}
#endif // defined(HEIGHT)

#if defined(DEPTH)
/** This kernel performs reduction on z-axis.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
 * @note The data type of the intermediate results must be passed at compile time using -DDATA_TYPE_PROMOTED: e.g. -DDATA_TYPE_PROMOTED=uint
 * @note The depth size must be passed at compile time using -DDEPTH e.g. -DDEPTH=128
 *
 * @param[in] input_ptr                            Pointer to the source tensor. Supported data types: S32/F16/F32
 * @param[in] input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in] input_step_x                         input_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in] input_step_y                         input_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in] input_step_z                         input_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[in] output_ptr                           The local buffer to hold the result values. Supported data types: U32/S32
 * @param[in] output_stride_x                      Stride of the output tensor in X dimension (in bytes)
 * @param[in] output_step_x                        output_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] output_stride_y                      Stride of the output tensor in Y dimension (in bytes)
 * @param[in] output_step_y                        output_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] output_stride_z                      Stride of the output tensor in Z dimension (in bytes)
 * @param[in] output_step_z                        output_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] output_offset_first_element_in_bytes The offset of the first element in the output tensor
 */
__kernel void arg_min_max_z(
    TENSOR3D_DECLARATION(input),
    TENSOR3D_DECLARATION(output))
{
    Tensor3D input  = CONVERT_TO_TENSOR3D_STRUCT(input);
    Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT(output);

    // Running extremum for each of the 16 x-positions handled by this
    // work-item, promoted to DATA_TYPE_PROMOTED for the comparisons.
    VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 16)
    ref_values = CONVERT(vload16(0, (__global DATA_TYPE *)tensor3D_offset(&input, 0, 0, 0)), VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 16));

    // Z index at which the current extremum of each lane was found.
    VEC_DATA_TYPE(DATA_TYPE_OUTPUT, 16)
    best_idx = 0;
    for(DATA_TYPE_OUTPUT plane = 1; plane < DEPTH; ++plane)
    {
        VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 16)
        plane_values = CONVERT(vload16(0, (__global DATA_TYPE *)tensor3D_offset(&input, 0, 0, plane)), VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 16));

        // Per-lane mask: non-zero where the new plane beats the current extremum.
        VEC_DATA_TYPE(DATA_TYPE_OUTPUT, 16)
        update_mask = CONVERT(CONDITION_TO_USE(plane_values, ref_values), VEC_DATA_TYPE(DATA_TYPE_OUTPUT, 16));
        best_idx    = select(best_idx, plane, update_mask);
        ref_values  = select(ref_values, plane_values, CONDITION_TO_USE(plane_values, ref_values));
    }

    // Store the winning z indices
    vstore16(best_idx, 0, (__global DATA_TYPE_OUTPUT *)output.ptr);
}
#endif /* defined(DEPTH) */

#if defined(BATCH) && defined(DEPTH)
/** This kernel performs reduction on w-axis.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
 * @note The data type of the intermediate results must be passed at compile time using -DDATA_TYPE_PROMOTED: e.g. -DDATA_TYPE_PROMOTED=uint
 * @note The batch size must be passed at compile time using -DBATCH e.g. -DBATCH=128
 * @note The depth size must be passed at compile time using -DDEPTH e.g. -DDEPTH=128
 *
 * @param[in] input_ptr                            Pointer to the source tensor. Supported data types: S32/F16/F32
 * @param[in] input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in] input_step_x                         input_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in] input_step_y                         input_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in] input_step_z                         input_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] input_stride_w                       Stride of the source tensor in W dimension (in bytes)
 * @param[in] input_step_w                         input_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in] input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[in] output_ptr                           The local buffer to hold the result values. Supported data types: U32/S32
 * @param[in] output_stride_x                      Stride of the output tensor in X dimension (in bytes)
 * @param[in] output_step_x                        output_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] output_stride_y                      Stride of the output tensor in Y dimension (in bytes)
 * @param[in] output_step_y                        output_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] output_stride_z                      Stride of the output tensor in Z dimension (in bytes)
 * @param[in] output_step_z                        output_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] output_stride_w                      Stride of the output tensor in W dimension (in bytes)
 * @param[in] output_step_w                        output_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in] output_offset_first_element_in_bytes The offset of the first element in the output tensor
 */
__kernel void arg_min_max_w(
    TENSOR4D_DECLARATION(input),
    TENSOR4D_DECLARATION(output))
{
    Tensor4D input  = CONVERT_TO_TENSOR4D_STRUCT(input, DEPTH);
    Tensor4D output = CONVERT_TO_TENSOR4D_STRUCT(output, DEPTH);

    // Running extremum for each of the 16 x-positions handled by this
    // work-item, promoted to DATA_TYPE_PROMOTED for the comparisons.
    VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 16)
    ref_values = CONVERT(vload16(0, (__global DATA_TYPE *)tensor4D_offset(&input, 0, 0, 0, 0)), VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 16));

    // W (batch) index at which the current extremum of each lane was found.
    VEC_DATA_TYPE(DATA_TYPE_OUTPUT, 16)
    best_idx = 0;
    for(DATA_TYPE_OUTPUT batch = 1; batch < BATCH; ++batch)
    {
        VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 16)
        batch_values = CONVERT(vload16(0, (__global DATA_TYPE *)tensor4D_offset(&input, 0, 0, 0, batch)), VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 16));

        // Per-lane mask: non-zero where the new batch beats the current extremum.
        VEC_DATA_TYPE(DATA_TYPE_OUTPUT, 16)
        update_mask = CONVERT(CONDITION_TO_USE(batch_values, ref_values), VEC_DATA_TYPE(DATA_TYPE_OUTPUT, 16));
        best_idx    = select(best_idx, batch, update_mask);
        ref_values  = select(ref_values, batch_values, CONDITION_TO_USE(batch_values, ref_values));
    }

    // Store the winning w indices
    vstore16(best_idx, 0, (__global DATA_TYPE_OUTPUT *)output.ptr);
}
#endif /* defined(BATCH) && defined(DEPTH) */
#endif // defined(DATA_TYPE_OUTPUT)