/*
 * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "helpers.h"

#define MAX_OP(x, y, type, size) max((x), (y))
#define ADD_OP(x, y, type, size) ((x) + (y))
#define SUB_OP(x, y, type, size) ((x) - (y))
#define MUL_OP(x, y, type, size) ((x) * (y))
#define DIV_OP(x, y, type, size) ((x) / (y))
#define EXP_OP(x, type, size) exp((x))

#ifdef USE_F16
#define MINVAL -HALF_MAX
#define SELECT_DATA_TYPE short
#else /* USE_F16 */
#define MINVAL -FLT_MAX
#define SELECT_DATA_TYPE int
#endif /* USE_F16 */

/* Number of workitems in dimension 0. */
#if !defined(GRID_SIZE)
#define GRID_SIZE 1
#endif /* !defined(GRID_SIZE) */

/* Vector size, i.e. number of vector elements. */
#if VECTOR_SIZE == 2
__constant VEC_DATA_TYPE(DATA_TYPE, 2) type_min_ = (VEC_DATA_TYPE(DATA_TYPE, 2))(MINVAL);
__constant uint2 idx__ = (uint2)(0, 1);

#elif VECTOR_SIZE == 4
__constant VEC_DATA_TYPE(DATA_TYPE, 4) type_min_ = (VEC_DATA_TYPE(DATA_TYPE, 4))(MINVAL);
__constant uint4 idx__ = (uint4)(0, 1, 2, 3);

#elif VECTOR_SIZE == 8
__constant VEC_DATA_TYPE(DATA_TYPE, 8) type_min_ = (VEC_DATA_TYPE(DATA_TYPE, 8))(MINVAL);
__constant uint8 idx__ = (uint8)(0, 1, 2, 3, 4, 5, 6, 7);

#else /* VECTOR_SIZE DEFAULT */
#define VECTOR_SIZE 16
#define LOG_VECTOR_SIZE 4
__constant VEC_DATA_TYPE(DATA_TYPE, 16) type_min_ = (VEC_DATA_TYPE(DATA_TYPE, 16))(MINVAL);
__constant uint16 idx__ = (uint16)(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);

#endif /* VECTOR_SIZE END */

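/* Note on configuration (illustrative, not an exhaustive list of build options): the kernels
 * below step through each row with "i << LOG_VECTOR_SIZE", so when the host selects
 * -DVECTOR_SIZE=2, 4 or 8 it is expected to also pass a matching -DLOG_VECTOR_SIZE
 * (1, 2 or 3 respectively); only the default 16/4 pair is defined in this file.
 * A hypothetical option string for an F32 build could look like:
 *   -DDATA_TYPE=float -DVECTOR_SIZE=8 -DLOG_VECTOR_SIZE=3 -DBETA=1.0f
 */
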
// TODO (COMPMID-661): Remove if the non-fused kernels are removed
__constant VEC_DATA_TYPE(DATA_TYPE, 16) type_min = (VEC_DATA_TYPE(DATA_TYPE, 16))(MINVAL);
__constant uint16 idx16 = (uint16)(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
__constant uint4 idx4 = (uint4)(0, 1, 2, 3);

/** Divides all the values of the input tensor by the sum calculated by the softmax_layer_max_shift_exp_sum_* kernels.
 *
 * @note Datatype must be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
 *
 * @param[in] src_ptr Pointer to the source tensor slice. Supported data types: F16/F32
 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] sum_ptr Pointer to the sum values tensor slice. Supported data types: same as @p src_ptr
 * @param[in] sum_stride_x Stride of the sum values tensor in X dimension (in bytes)
 * @param[in] sum_step_x sum_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] sum_stride_y Stride of the sum values tensor in Y dimension (in bytes)
 * @param[in] sum_step_y sum_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] sum_stride_z Stride of the sum values tensor in Z dimension (in bytes)
 * @param[in] sum_step_z sum_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] sum_offset_first_element_in_bytes The offset of the first element in the sum values tensor
 * @param[out] dst_ptr Pointer to the destination tensor slice. Supported data types: same as @p src_ptr
 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void softmax_layer_norm(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(sum),
    TENSOR3D_DECLARATION(dst))
{
    Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
    Image sum = CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP(sum);

    // Load the sum value of the 1D logits vector (row)
    DATA_TYPE sum_val = *((__global DATA_TYPE *)offset(&sum, 0, get_global_id(1)));
    VEC_DATA_TYPE(DATA_TYPE, 16)
    data = vload16(0, (__global DATA_TYPE *)offset(&src, 0, 0));
#ifdef LOG_SOFTMAX
    vstore16(SUB_OP(data, sum_val, DATA_TYPE, 16), 0, (__global DATA_TYPE *)offset(&dst, 0, 0));
#else /* LOG_SOFTMAX */
    vstore16(DIV_OP(data, sum_val, DATA_TYPE, 16), 0, (__global DATA_TYPE *)offset(&dst, 0, 0));
#endif /* LOG_SOFTMAX */
}
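
/* A minimal scalar sketch of the normalization step performed by softmax_layer_norm, for
 * reference only. This helper is hypothetical: it is not called by any kernel in this file and
 * is guarded by SOFTMAX_LAYER_REFERENCE_EXAMPLE (a made-up define) so it is never built by
 * default. It assumes `row` holds the `width` values written by the max_shift_exp_sum kernels
 * for one row and `sum_val` is the corresponding value from the sum tensor.
 */
#ifdef SOFTMAX_LAYER_REFERENCE_EXAMPLE
void softmax_norm_row_reference(__global DATA_TYPE *row, uint width, DATA_TYPE sum_val)
{
    for(uint x = 0; x < width; ++x)
    {
#ifdef LOG_SOFTMAX
        // Log-softmax path: mirrors SUB_OP in the kernel above
        row[x] = row[x] - sum_val;
#else  /* LOG_SOFTMAX */
        // Softmax path: mirrors DIV_OP in the kernel above
        row[x] = row[x] / sum_val;
#endif /* LOG_SOFTMAX */
    }
}
#endif /* SOFTMAX_LAYER_REFERENCE_EXAMPLE */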

/** Identifies the maximum value across the 1st dimension and shifts the values of the input tensor by this maximum value,
 * then computes the exponential of each shifted element and sums all elements across each row.
 *
 * @note Datatype must be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
 * @note In case the input is not a multiple of VECTOR_SIZE (2,4,8,16) -DNON_MULTIPLE_OF_VECTOR_SIZE must be passed.
 * @note Beta can be optionally passed at compile time using -DBETA (by default, it is 1.0).
 *
 * @param[in] src_ptr Pointer to the source tensor slice. Supported data types: F16/F32
 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] maxo_ptr Pointer to the max values tensor slice. Supported data types: same as @p src_ptr
 * @param[in] maxo_stride_x Stride of the max values tensor in X dimension (in bytes)
 * @param[in] maxo_step_x maxo_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] maxo_stride_y Stride of the max values tensor in Y dimension (in bytes)
 * @param[in] maxo_step_y maxo_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] maxo_stride_z Stride of the max values tensor in Z dimension (in bytes)
 * @param[in] maxo_step_z maxo_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] maxo_offset_first_element_in_bytes The offset of the first element in the max values tensor
 * @param[out] dst_ptr Pointer to the destination tensor slice. Supported data types: same as @p src_ptr
 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[out] sum_ptr Pointer to the sum values tensor slice. Supported data types: same as @p src_ptr
 * @param[in] sum_stride_x Stride of the sum values tensor in X dimension (in bytes)
 * @param[in] sum_step_x sum_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] sum_stride_y Stride of the sum values tensor in Y dimension (in bytes)
 * @param[in] sum_step_y sum_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] sum_stride_z Stride of the sum values tensor in Z dimension (in bytes)
 * @param[in] sum_step_z sum_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] sum_offset_first_element_in_bytes The offset of the first element in the sum values tensor
 * @param[in] width Input image width
 */
__kernel void softmax_layer_max_shift_exp_sum_serial(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(maxo),
    TENSOR3D_DECLARATION(dst),
    TENSOR3D_DECLARATION(sum),
    uint width)
{
    Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
    Image maxo = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(maxo);
    Image sum = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(sum);

#ifdef BETA
    // Initialize beta
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    beta = (VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE))BETA;
#endif /* BETA */

    // Initialize local maximum
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    max_val_vec = (VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE))type_min_;

    // Calculate max of row
    const uint width_ = width >> LOG_VECTOR_SIZE;
    for(uint i = 0; i < width_; i++)
    {
        VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
        data_max = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)offset(&src, i << LOG_VECTOR_SIZE, 0));
        max_val_vec = MAX_OP(data_max, max_val_vec, DATA_TYPE, VECTOR_SIZE);
    }

#ifdef NON_MULTIPLE_OF_VECTOR_SIZE
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    data_max = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)offset(&src, width_ << LOG_VECTOR_SIZE, 0));
    VEC_DATA_TYPE(SELECT_DATA_TYPE, VECTOR_SIZE)
    widx = CONVERT((EXPAND((CL_VEC_DATA_TYPE(uint, VECTOR_SIZE)))(width_ << LOG_VECTOR_SIZE) + idx__) < width, VEC_DATA_TYPE(SELECT_DATA_TYPE, VECTOR_SIZE));
    max_val_vec = MAX_OP(max_val_vec, select(type_min_, data_max, widx), DATA_TYPE, VECTOR_SIZE);
#endif /* NON_MULTIPLE_OF_VECTOR_SIZE */

    // Perform max reduction
#if VECTOR_SIZE == 16
    max_val_vec.s01234567 = MAX_OP(max_val_vec.s01234567, max_val_vec.s89ABCDEF, DATA_TYPE, 8);
#endif /* VECTOR SIZE 16 END */
#if VECTOR_SIZE >= 8
    max_val_vec.s0123 = MAX_OP(max_val_vec.s0123, max_val_vec.s4567, DATA_TYPE, 4);
#endif /* VECTOR SIZE 8 END */
#if VECTOR_SIZE >= 4
    max_val_vec.s01 = MAX_OP(max_val_vec.s01, max_val_vec.s23, DATA_TYPE, 2);
#endif /* VECTOR SIZE 4 END */
    max_val_vec.s0 = MAX_OP(max_val_vec.s0, max_val_vec.s1, DATA_TYPE, 1);
    // Store result
    *((__global DATA_TYPE *)maxo.ptr) = max_val_vec.s0;

    /* Second section */

    // Load max value of 1D logits vector (row)
    DATA_TYPE max_val = *((__global DATA_TYPE *)offset(&maxo, 0, 0));

    // Set sum vector
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    sum1D = 0;

    // Shift values, exp and sum
    for(uint i = 0; i < width_; i++)
    {
        VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
        data = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)offset(&src, i << LOG_VECTOR_SIZE, 0));
        data = SUB_OP(data, max_val, DATA_TYPE, VECTOR_SIZE);
#ifdef BETA
        data = MUL_OP(data, beta, DATA_TYPE, VECTOR_SIZE);
#endif /* BETA */
#ifdef LOG_SOFTMAX
        VSTORE(VECTOR_SIZE)
        (data, 0, (__global DATA_TYPE *)offset(&dst, i << LOG_VECTOR_SIZE, 0));
        data = EXP_OP(data, DATA_TYPE, VECTOR_SIZE);
#else /* LOG_SOFTMAX */
        data = EXP_OP(data, DATA_TYPE, VECTOR_SIZE);
        VSTORE(VECTOR_SIZE)
        (data, 0, (__global DATA_TYPE *)offset(&dst, i << LOG_VECTOR_SIZE, 0));
#endif /* LOG_SOFTMAX */
        sum1D = ADD_OP(sum1D, data, DATA_TYPE, VECTOR_SIZE);
    }

#ifdef NON_MULTIPLE_OF_VECTOR_SIZE
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    data = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)offset(&src, width_ << LOG_VECTOR_SIZE, 0));
    data = SUB_OP(data, max_val, DATA_TYPE, VECTOR_SIZE);
#ifdef BETA
    data = MUL_OP(data, beta, DATA_TYPE, VECTOR_SIZE);
#endif /* BETA */
#ifdef LOG_SOFTMAX
    VSTORE(VECTOR_SIZE)
    (data, 0, (__global DATA_TYPE *)offset(&dst, width_ << LOG_VECTOR_SIZE, 0));
    data = EXP_OP(data, DATA_TYPE, VECTOR_SIZE);
    widx = CONVERT((EXPAND((CL_VEC_DATA_TYPE(uint, VECTOR_SIZE)))(width_ << LOG_VECTOR_SIZE) + idx__) < width, VEC_DATA_TYPE(SELECT_DATA_TYPE, VECTOR_SIZE));
    data = select(0, data, widx);
#else /* LOG_SOFTMAX */
    data = EXP_OP(data, DATA_TYPE, VECTOR_SIZE);
    widx = CONVERT((EXPAND((CL_VEC_DATA_TYPE(uint, VECTOR_SIZE)))(width_ << LOG_VECTOR_SIZE) + idx__) < width, VEC_DATA_TYPE(SELECT_DATA_TYPE, VECTOR_SIZE));
    data = select(0, data, widx);
    VSTORE(VECTOR_SIZE)
    (data, 0, (__global DATA_TYPE *)offset(&dst, width_ << LOG_VECTOR_SIZE, 0));
#endif /* LOG_SOFTMAX */
    sum1D = ADD_OP(sum1D, data, DATA_TYPE, VECTOR_SIZE);
#endif /* NON_MULTIPLE_OF_VECTOR_SIZE */

    // Perform sum reduction
#if VECTOR_SIZE == 16
    sum1D.s01234567 = ADD_OP(sum1D.s01234567, sum1D.s89ABCDEF, DATA_TYPE, 8);
#endif /* VECTOR SIZE 16 END */
#if VECTOR_SIZE >= 8
    sum1D.s0123 = ADD_OP(sum1D.s0123, sum1D.s4567, DATA_TYPE, 4);
#endif /* VECTOR SIZE 8 END */
#if VECTOR_SIZE >= 4
    sum1D.s01 = ADD_OP(sum1D.s01, sum1D.s23, DATA_TYPE, 2);
#endif /* VECTOR SIZE 4 END */
    sum1D.s0 = ADD_OP(sum1D.s0, sum1D.s1, DATA_TYPE, 1);

    // Calculate and store result
    *((__global DATA_TYPE *)sum.ptr) = sum1D.s0;
}
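
/* Worked example (illustrative numbers only): for a row x = {1, 2, 3} with BETA left undefined,
 * the kernel above stores max = 3 in maxo, writes dst = exp(x - 3) ~= {0.135, 0.368, 1.0} and
 * stores sum ~= 1.503; softmax_layer_norm then divides each dst value by that sum.
 * With -DLOG_SOFTMAX the shifted values {-2, -1, 0} are written to dst instead, while the
 * exponentials are still accumulated into the same sum.
 */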

/** Identifies the maximum value across the 1st dimension and shifts the values of the input tensor by this maximum value,
 * then computes the exponential of each shifted element and sums all elements across each row.
 *
 * @note Datatype must be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
 * @note In case the input is not a multiple of VECTOR_SIZE (2,4,8,16) -DNON_MULTIPLE_OF_VECTOR_SIZE must be passed.
 * @note Beta can be optionally passed at compile time using -DBETA (by default, it is 1.0).
 *
 * @param[in] src_ptr Pointer to the source tensor slice. Supported data types: F16/F32
 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] maxo_ptr Pointer to the max values tensor slice. Supported data types: same as @p src_ptr
 * @param[in] maxo_stride_x Stride of the max values tensor in X dimension (in bytes)
 * @param[in] maxo_step_x maxo_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] maxo_stride_y Stride of the max values tensor in Y dimension (in bytes)
 * @param[in] maxo_step_y maxo_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] maxo_stride_z Stride of the max values tensor in Z dimension (in bytes)
 * @param[in] maxo_step_z maxo_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] maxo_offset_first_element_in_bytes The offset of the first element in the max values tensor
 * @param[out] dst_ptr Pointer to the destination tensor slice. Supported data types: same as @p src_ptr
 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[out] sum_ptr Pointer to the sum values tensor slice. Supported data types: same as @p src_ptr
 * @param[in] sum_stride_x Stride of the sum values tensor in X dimension (in bytes)
 * @param[in] sum_step_x sum_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] sum_stride_y Stride of the sum values tensor in Y dimension (in bytes)
 * @param[in] sum_step_y sum_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] sum_stride_z Stride of the sum values tensor in Z dimension (in bytes)
 * @param[in] sum_step_z sum_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] sum_offset_first_element_in_bytes The offset of the first element in the sum values tensor
 * @param[in] width Input image width
 */
__kernel void softmax_layer_max_shift_exp_sum_parallel(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(maxo),
    TENSOR3D_DECLARATION(dst),
    TENSOR3D_DECLARATION(sum),
    uint width)
{
    Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
    Image maxo = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(maxo);
    Image sum = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(sum);

    const uint lid = get_local_id(0);

#ifdef BETA
    // Initialize beta
    VEC_DATA_TYPE(DATA_TYPE, 4)
    beta = (VEC_DATA_TYPE(DATA_TYPE, 4))BETA;
#endif /* BETA */

    // Define one temporary vector per work-item.
    __local VEC_DATA_TYPE(DATA_TYPE, 4) tmp_local[GRID_SIZE];
    __local DATA_TYPE max_local;

    __constant VEC_DATA_TYPE(DATA_TYPE, 4) type_min4 = (VEC_DATA_TYPE(DATA_TYPE, 4))(MINVAL);
    VEC_DATA_TYPE(DATA_TYPE, 4)
    max_val_vec = (VEC_DATA_TYPE(DATA_TYPE, 4))type_min4;
    // Number of elements per work-item.
    const uint row = width / GRID_SIZE;
    // Number of iterations per work-item.
    const uint width_ = row >> 2;
    // Calculate max of row
    uint i = 0;
    for(; i < width_; i++)
    {
        VEC_DATA_TYPE(DATA_TYPE, 4)
        data_max = VLOAD(4)(0, (__global DATA_TYPE *)offset(&src, i * GRID_SIZE * 4, 0));
        max_val_vec = MAX_OP(data_max, max_val_vec, DATA_TYPE, 4);
    }
#ifdef NON_MULTIPLE_OF_GRID_SIZE
    // How many work-items needed to complete the computation.
    //TODO: Optimize this calculation (avoid %).
    int boundary_workitems = (width % (GRID_SIZE * 4)) / 4;
    if(lid < boundary_workitems)
    {
        VEC_DATA_TYPE(DATA_TYPE, 4)
        data_max = VLOAD(4)(0, (__global DATA_TYPE *)offset(&src, i * GRID_SIZE * 4, 0));
        max_val_vec = MAX_OP(data_max, max_val_vec, DATA_TYPE, 4);
    }
#ifdef NON_MULTIPLE_OF_VECTOR_SIZE
    if(boundary_workitems == 0)
    {
        boundary_workitems = GRID_SIZE;
        i--;
    }
    if(lid == (boundary_workitems - 1))
    {
        // Handle non multiple of 4
        VEC_DATA_TYPE(DATA_TYPE, 4)
        data_max = VLOAD(4)(0, (__global DATA_TYPE *)offset(&src, (GRID_SIZE * i * 4) + 4, 0));
        VEC_DATA_TYPE(SELECT_DATA_TYPE, 4)
        widx = CONVERT(((uint4)(GRID_SIZE * i * 4) + boundary_workitems * 4 + idx4) < width, VEC_DATA_TYPE(SELECT_DATA_TYPE, 4));
        max_val_vec = MAX_OP(max_val_vec, select(type_min_, data_max, widx), DATA_TYPE, 4);
    }
#endif /* NON_MULTIPLE_OF_VECTOR_SIZE */
#endif /* NON_MULTIPLE_OF_GRID_SIZE */
    tmp_local[lid] = max_val_vec;

    barrier(CLK_LOCAL_MEM_FENCE);

    if(GRID_SIZE >= 256)
    {
        if(lid < 128)
        {
            tmp_local[lid] = MAX_OP(tmp_local[lid + 128], tmp_local[lid], DATA_TYPE, 4);
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    if(GRID_SIZE >= 128)
    {
        if(lid < 64)
        {
            tmp_local[lid] = MAX_OP(tmp_local[lid + 64], tmp_local[lid], DATA_TYPE, 4);
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    if(GRID_SIZE >= 64)
    {
        if(lid < 32)
        {
            tmp_local[lid] = MAX_OP(tmp_local[lid + 32], tmp_local[lid], DATA_TYPE, 4);
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    if(GRID_SIZE >= 32)
    {
        if(lid < 16)
        {
            tmp_local[lid] = MAX_OP(tmp_local[lid + 16], tmp_local[lid], DATA_TYPE, 4);
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    if(GRID_SIZE >= 16)
    {
        if(lid < 8)
        {
            tmp_local[lid] = MAX_OP(tmp_local[lid + 8], tmp_local[lid], DATA_TYPE, 4);
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    if(GRID_SIZE >= 8)
    {
        if(lid < 4)
        {
            tmp_local[lid] = MAX_OP(tmp_local[lid + 4], tmp_local[lid], DATA_TYPE, 4);
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    if(GRID_SIZE >= 4)
    {
        if(lid < 2)
        {
            tmp_local[lid] = MAX_OP(tmp_local[lid + 2], tmp_local[lid], DATA_TYPE, 4);
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    if(lid == 0)
    {
        max_val_vec = MAX_OP(tmp_local[lid + 1], tmp_local[lid], DATA_TYPE, 4);
        max_val_vec.s01 = MAX_OP(max_val_vec.s01, max_val_vec.s23, DATA_TYPE, 2);
        max_val_vec.s0 = MAX_OP(max_val_vec.s0, max_val_vec.s1, DATA_TYPE, 1);
        max_local = max_val_vec.s0;
    }
    barrier(CLK_LOCAL_MEM_FENCE);

    /* Second section */

    // Set sum vector
    VEC_DATA_TYPE(DATA_TYPE, 4)
    sum1D = 0;
    DATA_TYPE max_val = max_local;

    // Shift values, exp and sum
    for(i = 0; i < width_; i++)
    {
        VEC_DATA_TYPE(DATA_TYPE, 4)
        data = VLOAD(4)(0, (__global DATA_TYPE *)offset(&src, i * GRID_SIZE * 4, 0));
        data = SUB_OP(data, max_val, DATA_TYPE, 4);
#ifdef BETA
        data = MUL_OP(data, beta, DATA_TYPE, 4);
#endif /* BETA */
#ifdef LOG_SOFTMAX
        VSTORE(4)
        (data, 0, (__global DATA_TYPE *)offset(&dst, i * GRID_SIZE * 4, 0));
        data = EXP_OP(data, DATA_TYPE, 4);
#else /* LOG_SOFTMAX */
        data = EXP_OP(data, DATA_TYPE, 4);
        VSTORE(4)
        (data, 0, (__global DATA_TYPE *)offset(&dst, i * GRID_SIZE * 4, 0));
#endif /* LOG_SOFTMAX */
        sum1D = ADD_OP(sum1D, data, DATA_TYPE, 4);
    }
#ifdef NON_MULTIPLE_OF_GRID_SIZE
    //TODO: Optimize the calculation (avoid %).
    boundary_workitems = (width % (GRID_SIZE * 4)) / 4;
    if(lid < boundary_workitems)
    {
        VEC_DATA_TYPE(DATA_TYPE, 4)
        data = VLOAD(4)(0, (__global DATA_TYPE *)offset(&src, i * GRID_SIZE * 4, 0));
        data = SUB_OP(data, max_val, DATA_TYPE, 4);
#ifdef BETA
        data = MUL_OP(data, beta, DATA_TYPE, 4);
#endif /* BETA */
#ifdef LOG_SOFTMAX
        VSTORE(4)
        (data, 0, (__global DATA_TYPE *)offset(&dst, i * GRID_SIZE * 4, 0));
        data = EXP_OP(data, DATA_TYPE, 4);
#else /* LOG_SOFTMAX */
        data = EXP_OP(data, DATA_TYPE, 4);
        VSTORE(4)
        (data, 0, (__global DATA_TYPE *)offset(&dst, i * GRID_SIZE * 4, 0));
#endif /* LOG_SOFTMAX */
        sum1D = ADD_OP(sum1D, data, DATA_TYPE, 4);
    }
#ifdef NON_MULTIPLE_OF_VECTOR_SIZE
    if(boundary_workitems == 0)
    {
        boundary_workitems = GRID_SIZE;
        i--;
    }
    if(lid == (boundary_workitems - 1))
    {
        // Handle non multiple of vector size ((GRID_SIZE * i * 4) + 4, 0); move 4 float positions ahead, *4 is due to the stride
        VEC_DATA_TYPE(DATA_TYPE, 4)
        data = VLOAD(4)(0, (__global DATA_TYPE *)offset(&src, (GRID_SIZE * i * 4) + 4, 0));
        data = SUB_OP(data, max_val, DATA_TYPE, 4);
#ifdef BETA
        data = MUL_OP(data, beta, DATA_TYPE, 4);
#endif /* BETA */
#ifdef LOG_SOFTMAX
        VSTORE(4)
        (data, 0, (__global DATA_TYPE *)offset(&dst, (GRID_SIZE * i * 4) + 4, 0));
        data = EXP_OP(data, DATA_TYPE, 4);
        VEC_DATA_TYPE(SELECT_DATA_TYPE, 4)
        widx = CONVERT(((uint4)(GRID_SIZE * i * 4) + boundary_workitems * 4 + idx4) < width, VEC_DATA_TYPE(SELECT_DATA_TYPE, 4));
        data = select(0, data, widx);
#else /* LOG_SOFTMAX */
        data = EXP_OP(data, DATA_TYPE, 4);
        VEC_DATA_TYPE(SELECT_DATA_TYPE, 4)
        widx = CONVERT(((uint4)(GRID_SIZE * i * 4) + boundary_workitems * 4 + idx4) < width, VEC_DATA_TYPE(SELECT_DATA_TYPE, 4));
        data = select(0, data, widx);
        VSTORE(4)
        (data, 0, (__global DATA_TYPE *)offset(&dst, (GRID_SIZE * i * 4) + 4, 0));
#endif /* LOG_SOFTMAX */
        sum1D = ADD_OP(sum1D, data, DATA_TYPE, 4);
    }
#endif /* NON_MULTIPLE_OF_VECTOR_SIZE */
#endif /* NON_MULTIPLE_OF_GRID_SIZE */
    tmp_local[lid] = sum1D;

    barrier(CLK_LOCAL_MEM_FENCE);

    if(GRID_SIZE >= 256)
    {
        if(lid < 128)
        {
            tmp_local[lid] = ADD_OP(tmp_local[lid + 128], tmp_local[lid], DATA_TYPE, 4);
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    if(GRID_SIZE >= 128)
    {
        if(lid < 64)
        {
            tmp_local[lid] = ADD_OP(tmp_local[lid + 64], tmp_local[lid], DATA_TYPE, 4);
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    if(GRID_SIZE >= 64)
    {
        if(lid < 32)
        {
            tmp_local[lid] = ADD_OP(tmp_local[lid + 32], tmp_local[lid], DATA_TYPE, 4);
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    if(GRID_SIZE >= 32)
    {
        if(lid < 16)
        {
            tmp_local[lid] = ADD_OP(tmp_local[lid + 16], tmp_local[lid], DATA_TYPE, 4);
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    if(GRID_SIZE >= 16)
    {
        if(lid < 8)
        {
            tmp_local[lid] = ADD_OP(tmp_local[lid + 8], tmp_local[lid], DATA_TYPE, 4);
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    if(GRID_SIZE >= 8)
    {
        if(lid < 4)
        {
            tmp_local[lid] = ADD_OP(tmp_local[lid + 4], tmp_local[lid], DATA_TYPE, 4);
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    if(GRID_SIZE >= 4)
    {
        if(lid < 2)
        {
            tmp_local[lid] = ADD_OP(tmp_local[lid + 2], tmp_local[lid], DATA_TYPE, 4);
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    if(lid == 0)
    {
        sum1D = ADD_OP(tmp_local[lid + 1], tmp_local[lid], DATA_TYPE, 4);
        // Perform sum reduction
        sum1D.s01 = ADD_OP(sum1D.s01, sum1D.s23, DATA_TYPE, 2);
        sum1D.s0 = ADD_OP(sum1D.s0, sum1D.s1, DATA_TYPE, 1);
        *((__global DATA_TYPE *)sum.ptr) = sum1D.s0;
    }
}
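
/* The unrolled local-memory reductions above implement a standard tree reduction across the
 * GRID_SIZE work-items that cooperate on one row. A compact, hypothetical equivalent is
 * sketched below for reference only; it is not used by the kernels and is guarded by
 * SOFTMAX_LAYER_REFERENCE_EXAMPLE (a made-up define) so it is never built by default.
 */
#ifdef SOFTMAX_LAYER_REFERENCE_EXAMPLE
VEC_DATA_TYPE(DATA_TYPE, 4)
tree_reduce_add_reference(__local VEC_DATA_TYPE(DATA_TYPE, 4) * tmp, uint lid)
{
    // Halve the number of active work-items each step, accumulating pairs of partial sums
    for(uint stride = GRID_SIZE / 2; stride > 0; stride >>= 1)
    {
        if(lid < stride)
        {
            tmp[lid] = ADD_OP(tmp[lid + stride], tmp[lid], DATA_TYPE, 4);
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    return tmp[0];
}
#endif /* SOFTMAX_LAYER_REFERENCE_EXAMPLE */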