blob: bfc0995bb8997de15ae6d6442b7acb828cffc754 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Viet-Hoa Do29254ae2023-10-13 17:40:32 +01002 * Copyright (c) 2017-2021, 2023 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Viet-Hoa Do29254ae2023-10-13 17:40:32 +010024
Anthony Barbier6ff3b192017-09-04 18:44:23 +010025#include "helpers.h"
26
/* Lowest representable value for each supported DATA_TYPE.
 *
 * Used as the identity element when reducing with max(). Negative expansions
 * are parenthesized so they stay a single operand wherever they are pasted.
 */
#define MIN_VALUE_float (-FLT_MAX)
#define MIN_VALUE_half (-HALF_MAX)
#define MIN_VALUE_char CHAR_MIN
#define MIN_VALUE_uchar 0

/* Two-level expansion: the outer macro expands its argument (e.g. DATA_TYPE)
 * before the inner one pastes it onto the MIN_VALUE_ prefix.
 */
#define MIN_VALUE_TYPE_STR(data_type) MIN_VALUE_##data_type
#define MIN_VALUE_TYPE(data_type) MIN_VALUE_TYPE_STR(data_type)

/* Lowest value of the DATA_TYPE the kernel is being compiled for. */
#define MIN_VALUE MIN_VALUE_TYPE(DATA_TYPE)
35
#ifdef SOFTMAX_X

/** 3-pass softmax in the x dimension.
 *
 * Each work-item walks LENGTH consecutive elements starting at its own offset:
 *  - Pass 1 finds the maximum value.
 *  - Pass 2 subtracts the maximum, applies BETA, (unless IS_LOG) exponentiates, accumulates the sum of
 *    exponentials and stores the intermediate values in the temporary storage.
 *  - Pass 3 normalizes the intermediate values and writes the result to the destination tensor.
 *
 * When IS_QUANTIZED is not defined the destination tensor itself is used as temporary storage.
 *
 * List of preprocessors:
 * - DATA_TYPE: the input/output data type.
 * - TMP_DATA_TYPE: the data type used for computing and temporary tensor storage.
 *   If DATA_TYPE is quantized, TMP_DATA_TYPE is floating-point, otherwise TMP_DATA_TYPE is the same as DATA_TYPE.
 * - IS_LOG (optional): indicating whether this is log softmax.
 * - LENGTH: the number of elements in softmax axis in the input/output tensors.
 * - BETA: the beta coefficient.
 * - IS_QUANTIZED (optional): indicating whether the input/output data type is quantized data.
 * - VEC_SIZE: the size of the vector.
 *
 * Additional preprocessors in case IS_QUANTIZED is present:
 * - SRC_SCALE and SRC_OFFSET: the quantization information of the source tensor.
 * - DST_SCALE and DST_OFFSET: the quantization information of the destination tensor.
 *
 * @param[in] src_ptr                  Pointer to the source tensor.
 * @param[in] src_stride_0             Stride in bytes of the source tensor in the dimension corresponding to global ID 0.
 * @param[in] src_stride_1             Stride in bytes of the source tensor in the dimension corresponding to global ID 1.
 * @param[in] src_stride_2             Stride in bytes of the source tensor in the dimension corresponding to global ID 2.
 * @param[in] src_offset_first_element Offset of the first element in the source tensor.
 * @param[in] dst_ptr                  Pointer to the destination tensor.
 * @param[in] dst_stride_0             Stride in bytes of the destination tensor in the dimension corresponding to global ID 0.
 * @param[in] dst_stride_1             Stride in bytes of the destination tensor in the dimension corresponding to global ID 1.
 * @param[in] dst_stride_2             Stride in bytes of the destination tensor in the dimension corresponding to global ID 2.
 * @param[in] dst_offset_first_element Offset of the first element in the destination tensor.
 * @param[in] tmp_ptr                  (Only if IS_QUANTIZED) Pointer to the temporary tensor.
 * @param[in] tmp_stride_0             (Only if IS_QUANTIZED) Stride in bytes of the temporary tensor in the dimension corresponding to global ID 0.
 * @param[in] tmp_stride_1             (Only if IS_QUANTIZED) Stride in bytes of the temporary tensor in the dimension corresponding to global ID 1.
 * @param[in] tmp_stride_2             (Only if IS_QUANTIZED) Stride in bytes of the temporary tensor in the dimension corresponding to global ID 2.
 * @param[in] tmp_offset_first_element (Only if IS_QUANTIZED) Offset of the first element in the temporary tensor.
 */
__kernel void softmax_x(
    __global uchar *src_ptr,
    uint            src_stride_0,
    uint            src_stride_1,
    uint            src_stride_2,
    uint            src_offset_first_element,

    __global uchar *dst_ptr,
    uint            dst_stride_0,
    uint            dst_stride_1,
    uint            dst_stride_2,
    uint            dst_offset_first_element

#ifdef IS_QUANTIZED
    ,
    __global uchar *tmp_ptr,
    uint            tmp_stride_0,
    uint            tmp_stride_1,
    uint            tmp_stride_2,
    uint            tmp_offset_first_element
#endif // IS_QUANTIZED
)
{
    const int dim_0 = get_global_id(0);
    const int dim_1 = get_global_id(1);
    const int dim_2 = get_global_id(2);

    // Move every tensor pointer to the first element handled by this work-item.
    src_ptr += src_offset_first_element + dim_2 * src_stride_2 + dim_1 * src_stride_1 + dim_0 * src_stride_0;
    dst_ptr += dst_offset_first_element + dim_2 * dst_stride_2 + dim_1 * dst_stride_1 + dim_0 * dst_stride_0;

#ifdef IS_QUANTIZED
    tmp_ptr += tmp_offset_first_element + dim_2 * tmp_stride_2 + dim_1 * tmp_stride_1 + dim_0 * tmp_stride_0;
#else  // IS_QUANTIZED
    // No dedicated temporary tensor is passed in: intermediate values are kept in the destination tensor.
    __global uchar *tmp_ptr = dst_ptr;
#endif // IS_QUANTIZED

    // Calculate max value.
    DATA_TYPE max_value = MIN_VALUE;
    int       i         = 0;

    // Vectorized part. The loop stops while at least one element (and at most VEC_SIZE elements) remains,
    // the remainder is handled by the scalar loop below.
    for (i = 0; i < LENGTH - VEC_SIZE; i += VEC_SIZE)
    {
        VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE)));

        max_value = max(max_value, MAX_REDUCE(data, VEC_SIZE));
    }

    for (; i < LENGTH; ++i)
    {
        DATA_TYPE data = *(__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE));

        max_value = max(max_value, data);
    }

    // Regularize the data.
    TMP_DATA_TYPE sum_value = 0;

#ifdef IS_QUANTIZED
    // REGULARIZE(x) = BETA * (dequantize(x) - dequantize(max_value)), with the constant part folded into regularize_offset.
    TMP_DATA_TYPE max_value_f       = (CONVERT(max_value, TMP_DATA_TYPE) - SRC_OFFSET) * SRC_SCALE;
    TMP_DATA_TYPE regularize_offset = -SRC_OFFSET * SRC_SCALE * (TMP_DATA_TYPE)BETA - max_value_f * (TMP_DATA_TYPE)BETA;
#    define REGULARIZE(x) ((x) * SRC_SCALE * (TMP_DATA_TYPE)BETA + regularize_offset)
#else // IS_QUANTIZED
#    define REGULARIZE(x) (((x) - max_value) * (TMP_DATA_TYPE)BETA)
#endif // IS_QUANTIZED

    for (i = 0; i < LENGTH - VEC_SIZE; i += VEC_SIZE)
    {
        VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = CONVERT(VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE))), VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE));

        data = REGULARIZE(data);

#ifdef IS_LOG
        // Log softmax keeps the regularized value; only the sum uses the exponential.
        sum_value += SUM_REDUCE(exp(data), VEC_SIZE);
#else  // IS_LOG
        data = exp(data);
        sum_value += SUM_REDUCE(data, VEC_SIZE);
#endif // IS_LOG

        VSTORE(VEC_SIZE)(data, 0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE)));
    }

    for (; i < LENGTH; ++i)
    {
        TMP_DATA_TYPE data = CONVERT(*(__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE)), TMP_DATA_TYPE);

        data = REGULARIZE(data);

#ifdef IS_LOG
        sum_value += exp(data);
#else  // IS_LOG
        data = exp(data);
        sum_value += data;
#endif // IS_LOG

        *(__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE)) = data;
    }

#undef REGULARIZE

    // Normalize the data.
#ifdef IS_QUANTIZED
#    ifdef IS_LOG
    // Quantizing (x - log(sum)) gives x / DST_SCALE - log(sum) / DST_SCALE + DST_OFFSET:
    // the log term must be scaled as well before it is folded into the offset.
    TMP_DATA_TYPE norm_offset = -log(sum_value) / DST_SCALE + DST_OFFSET;
#        define NORMALIZE(SIZE, x) CONVERT_SAT_ROUND((x) / DST_SCALE + norm_offset, VEC_DATA_TYPE(DATA_TYPE, SIZE), rte)
#    else // IS_LOG
    TMP_DATA_TYPE norm_div = sum_value * DST_SCALE;
#        define NORMALIZE(SIZE, x) CONVERT_SAT(add_sat(CONVERT_SAT_ROUND((x) / norm_div, VEC_DATA_TYPE(int, SIZE), rte), DST_OFFSET), VEC_DATA_TYPE(DATA_TYPE, SIZE))
#    endif // IS_LOG
#else // IS_QUANTIZED
#    ifdef IS_LOG
#        define NORMALIZE(SIZE, x) ((x) - log(sum_value))
#    else // IS_LOG
#        define NORMALIZE(SIZE, x) ((x) / sum_value)
#    endif // IS_LOG
#endif // IS_QUANTIZED

    for (i = 0; i < LENGTH - VEC_SIZE; i += VEC_SIZE)
    {
        VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE)));

        VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) result = NORMALIZE(VEC_SIZE, data);

        VSTORE(VEC_SIZE)(result, 0, (__global DATA_TYPE *)(dst_ptr + i * sizeof(DATA_TYPE)));
    }

    for (; i < LENGTH; ++i)
    {
        TMP_DATA_TYPE data = *(__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE));

        DATA_TYPE result = NORMALIZE(1, data);

        *(__global DATA_TYPE *)(dst_ptr + i * sizeof(DATA_TYPE)) = result;
    }

#undef NORMALIZE
}

#endif // SOFTMAX_X
Giorgio Arena2d1a8352020-10-26 15:04:08 +0000208
#ifdef SOFTMAX_NON_X

/** 3-pass softmax in any dimension higher than the x dimension.
 *
 * Each work-item handles VEC_SIZE neighbouring elements in the x dimension and, for each of them, walks the
 * LENGTH elements along the softmax axis:
 *  - Pass 1 finds the per-lane maximum and copies the input into the temporary tensor.
 *  - Pass 2 subtracts the maximum, applies BETA, (unless IS_LOG) exponentiates, accumulates the per-lane sum of
 *    exponentials and stores the intermediate values in the temporary tensor.
 *  - Pass 3 normalizes the intermediate values and writes the result to the destination tensor.
 *
 * The temporary tensor is addressed as LENGTH consecutive vectors of VEC_SIZE elements of TMP_DATA_TYPE
 * per work-item.
 *
 * List of preprocessors:
 * - DATA_TYPE: the input/output data type.
 * - TMP_DATA_TYPE: the data type used for computing and temporary tensor storage.
 *   If DATA_TYPE is quantized, TMP_DATA_TYPE is floating-point, otherwise TMP_DATA_TYPE is the same as DATA_TYPE.
 * - IS_LOG (optional): indicating whether this is log softmax.
 * - LENGTH: the number of elements in softmax axis in the input/output tensors.
 * - BETA: the beta coefficient.
 * - IS_QUANTIZED (optional): indicating whether the input/output data type is quantized data.
 * - VEC_SIZE: the size of the vector.
 * - VEC_SIZE_LEFTOVER: the size of the leftover part.
 *
 * Additional preprocessors in case IS_QUANTIZED is present:
 * - SRC_SCALE and SRC_OFFSET: the quantization information of the source tensor.
 * - DST_SCALE and DST_OFFSET: the quantization information of the destination tensor.
 *
 * @param[in] src_ptr                  Pointer to the source tensor.
 * @param[in] src_stride_0             Stride in bytes of the source tensor in the dimension corresponding to global ID 0.
 * @param[in] src_stride_1             Stride in bytes of the source tensor in the dimension corresponding to global ID 1.
 * @param[in] src_stride_2             Stride in bytes of the source tensor in the dimension corresponding to global ID 2.
 * @param[in] src_offset_first_element Offset of the first element in the source tensor.
 * @param[in] dst_ptr                  Pointer to the destination tensor.
 * @param[in] dst_stride_0             Stride in bytes of the destination tensor in the dimension corresponding to global ID 0.
 * @param[in] dst_stride_1             Stride in bytes of the destination tensor in the dimension corresponding to global ID 1.
 * @param[in] dst_stride_2             Stride in bytes of the destination tensor in the dimension corresponding to global ID 2.
 * @param[in] dst_offset_first_element Offset of the first element in the destination tensor.
 * @param[in] tmp_ptr                  Pointer to the temporary tensor.
 * @param[in] tmp_stride_0             Stride in bytes of the temporary tensor in the dimension corresponding to global ID 0.
 * @param[in] tmp_stride_1             Stride in bytes of the temporary tensor in the dimension corresponding to global ID 1.
 * @param[in] tmp_stride_2             Stride in bytes of the temporary tensor in the dimension corresponding to global ID 2.
 * @param[in] tmp_offset_first_element Offset of the first element in the temporary tensor.
 * @param[in] src_stride_axis          Stride in bytes of the source tensor between consecutive elements along the softmax axis.
 * @param[in] dst_stride_axis          Stride in bytes of the destination tensor between consecutive elements along the softmax axis.
 */
__kernel void softmax_non_x(
    __global uchar *src_ptr,
    uint            src_stride_0,
    uint            src_stride_1,
    uint            src_stride_2,
    uint            src_offset_first_element,

    __global uchar *dst_ptr,
    uint            dst_stride_0,
    uint            dst_stride_1,
    uint            dst_stride_2,
    uint            dst_offset_first_element,

    __global uchar *tmp_ptr,
    uint            tmp_stride_0,
    uint            tmp_stride_1,
    uint            tmp_stride_2,
    uint            tmp_offset_first_element,

    uint src_stride_axis,
    uint dst_stride_axis)
{
    // First x index of this work-item. Every work-item except the first is shifted back so that the last
    // vector ends exactly at the tensor edge; the first one is clamped to 0.
    const int dim_0 = max((int)get_global_id(0) * VEC_SIZE - (VEC_SIZE - VEC_SIZE_LEFTOVER) % VEC_SIZE, 0);
    const int dim_1 = get_global_id(1);
    const int dim_2 = get_global_id(2);

    src_ptr += src_offset_first_element + dim_2 * src_stride_2 + dim_1 * src_stride_1 + dim_0 * src_stride_0;
    dst_ptr += dst_offset_first_element + dim_2 * dst_stride_2 + dim_1 * dst_stride_1 + dim_0 * dst_stride_0;
    tmp_ptr += tmp_offset_first_element + dim_2 * tmp_stride_2 + dim_1 * tmp_stride_1 + dim_0 * tmp_stride_0;

    // In case of processing quantized data, i.e. DATA_TYPE is smaller than TMP_DATA_TYPE:
    //
    // In the first pass (finding max), the quantized data is copied from the input tensor to the temporary tensor.
    // Dequantization is not needed to find the max value and since dequantization widens the data, we defer it
    // to the second pass to reduce memory bandwidth of the first pass.
    //
    // In the second pass, it reads the quantized data from the temporary tensor and writes the dequantized data
    // back to the temporary tensor.
    //
    // To avoid dequantized data overwriting the unprocessed quantized data in the temporary tensor,
    // this extra offset is introduced to store the quantized data at the end of the temporary tensor.
    // (When both types have the same size the offset is 0 and the data is converted in place.)
    //
    // Note: Another approach is to perform the second pass in reverse order, but for unexplained reasons
    //       it doesn't work on some devices.
    uint tmp_extra_offset = LENGTH * VEC_SIZE * (sizeof(TMP_DATA_TYPE) - sizeof(DATA_TYPE));

    // Calculate max value and store the input data to the temporary tensor in suitable format.
    VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) max_value = MIN_VALUE;
    int i = 0;

    for (i = 0; i < LENGTH; ++i)
    {
        VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_ptr + i * src_stride_axis));

        max_value = max(max_value, data);

        VSTORE(VEC_SIZE)(data, 0, (__global DATA_TYPE *)(tmp_ptr + tmp_extra_offset + i * VEC_SIZE * sizeof(DATA_TYPE)));
    }

    // Regularize the data.
    VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) sum_value = 0;

#ifdef IS_QUANTIZED
    // REGULARIZE(x) = BETA * (dequantize(x) - dequantize(max_value)), with the constant part folded into regularize_offset.
    VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) max_value_f = (CONVERT(max_value, VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE)) - SRC_OFFSET) * SRC_SCALE;
    VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) regularize_offset = -SRC_OFFSET * SRC_SCALE * (TMP_DATA_TYPE)BETA - max_value_f * (TMP_DATA_TYPE)BETA;
#    define REGULARIZE(x) ((x) * SRC_SCALE * (TMP_DATA_TYPE)BETA + regularize_offset)
#else // IS_QUANTIZED
#    define REGULARIZE(x) (((x) - max_value) * (TMP_DATA_TYPE)BETA)
#endif // IS_QUANTIZED

    for (i = 0; i < LENGTH; ++i)
    {
        VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = CONVERT(VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(tmp_ptr + tmp_extra_offset + i * VEC_SIZE * sizeof(DATA_TYPE))), VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE));

        data = REGULARIZE(data);

#ifdef IS_LOG
        // Log softmax keeps the regularized value; only the sum uses the exponential.
        sum_value += exp(data);
#else  // IS_LOG
        data = exp(data);
        sum_value += data;
#endif // IS_LOG

        VSTORE(VEC_SIZE)(data, 0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * VEC_SIZE * sizeof(TMP_DATA_TYPE)));
    }

#undef REGULARIZE

    // Normalize the data.
#ifdef IS_QUANTIZED
#    ifdef IS_LOG
    // Quantizing (x - log(sum)) gives x / DST_SCALE - log(sum) / DST_SCALE + DST_OFFSET:
    // the log term must be scaled as well before it is folded into the offset.
    VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) norm_offset = -log(sum_value) / DST_SCALE + DST_OFFSET;
#        define NORMALIZE(x) CONVERT_SAT_ROUND((x) / DST_SCALE + norm_offset, VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE), rte)
#    else // IS_LOG
    VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) norm_div = sum_value * DST_SCALE;
#        define NORMALIZE(x) CONVERT_SAT(add_sat(CONVERT_SAT_ROUND((x) / norm_div, VEC_DATA_TYPE(int, VEC_SIZE), rte), DST_OFFSET), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))
#    endif // IS_LOG
#else // IS_QUANTIZED
#    ifdef IS_LOG
#        define NORMALIZE(x) ((x) - log(sum_value))
#    else // IS_LOG
#        define NORMALIZE(x) ((x) / sum_value)
#    endif // IS_LOG
#endif // IS_QUANTIZED

    for (i = 0; i < LENGTH; ++i)
    {
        VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * VEC_SIZE * sizeof(TMP_DATA_TYPE)));

        // Named result0 because STORE_VECTOR_SELECT takes the base name "result"
        // (presumably the helper appends the row suffix 0 - confirm against helpers.h).
        VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) result0 = NORMALIZE(data);

        // Only the first work-item in x stores a partial (VEC_SIZE_LEFTOVER) vector.
        STORE_VECTOR_SELECT(result, DATA_TYPE, dst_ptr + i * dst_stride_axis, VEC_SIZE, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0)
    }

#undef NORMALIZE
}

#endif // SOFTMAX_NON_X
Giorgio Arena2d1a8352020-10-26 15:04:08 +0000363
/* Drop the file-local MIN_VALUE helpers so they cannot leak into any source
 * that is concatenated after this one.
 */

/* Per-type lowest values. */
#undef MIN_VALUE_uchar
#undef MIN_VALUE_char
#undef MIN_VALUE_half
#undef MIN_VALUE_float

/* Lookup machinery. */
#undef MIN_VALUE_TYPE_STR
#undef MIN_VALUE_TYPE
#undef MIN_VALUE