// blob: bfc0995bb8997de15ae6d6442b7acb828cffc754 [file] [log] [blame]
/*
* Copyright (c) 2017-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#include "helpers.h"
/* Lowest value of each supported DATA_TYPE; the kernels use it to seed the running maximum. */
#define MIN_VALUE_uchar 0
#define MIN_VALUE_char CHAR_MIN
#define MIN_VALUE_half -HALF_MAX
#define MIN_VALUE_float -FLT_MAX
/* Two-level expansion: the argument is macro-expanded first, then pasted onto the MIN_VALUE_ prefix. */
#define MIN_VALUE_TYPE_STR(type_name) MIN_VALUE_##type_name
#define MIN_VALUE_TYPE(type_name) MIN_VALUE_TYPE_STR(type_name)
/* Lowest value of the DATA_TYPE selected at build time. */
#define MIN_VALUE MIN_VALUE_TYPE(DATA_TYPE)
#ifdef SOFTMAX_X
/** 3-pass softmax in the x dimension.
 *
 * Each work-item processes LENGTH consecutive DATA_TYPE elements starting at its source address:
 *  1. Find the maximum input value.
 *  2. Compute (x - max) * BETA for every element (exponentiated unless IS_LOG is defined), accumulate the sum
 *     of the exponentials and store the intermediate values in the temporary buffer.
 *  3. Normalize the intermediate values (and re-quantize them if IS_QUANTIZED is defined) into the destination.
 *
 * List of preprocessors:
 * - DATA_TYPE: the input/output data type.
 * - TMP_DATA_TYPE: the data type used for computing and temporary tensor storage.
 *   If DATA_TYPE is quantized, TMP_DATA_TYPE is floating-point, otherwise TMP_DATA_TYPE is the same as DATA_TYPE.
 * - IS_LOG (optional): if defined, log softmax is computed. Only the presence of the definition is checked.
 * - LENGTH: the number of elements in softmax axis in the input/output tensors.
 * - BETA: the beta coefficient.
 * - IS_QUANTIZED (optional): indicating whether the input/output data type is quantized data.
 * - VEC_SIZE: the size of the vector.
 *
 * Additional preprocessors in case IS_QUANTIZED is present:
 * - SRC_SCALE and SRC_OFFSET: the quantization information of the source tensor.
 * - DST_SCALE and DST_OFFSET: the quantization information of the destination tensor.
 *
 * @note The tmp_* arguments are only part of the kernel signature when IS_QUANTIZED is defined.
 *       Otherwise the destination buffer is used as the temporary storage.
 *
 * @param[in]  src_ptr                  Pointer to the source tensor.
 * @param[in]  src_stride_0             Stride in bytes of the source tensor in the dimension corresponding to global ID 0.
 * @param[in]  src_stride_1             Stride in bytes of the source tensor in the dimension corresponding to global ID 1.
 * @param[in]  src_stride_2             Stride in bytes of the source tensor in the dimension corresponding to global ID 2.
 * @param[in]  src_offset_first_element Offset of the first element in the source tensor.
 * @param[out] dst_ptr                  Pointer to the destination tensor.
 * @param[in]  dst_stride_0             Stride in bytes of the destination tensor in the dimension corresponding to global ID 0.
 * @param[in]  dst_stride_1             Stride in bytes of the destination tensor in the dimension corresponding to global ID 1.
 * @param[in]  dst_stride_2             Stride in bytes of the destination tensor in the dimension corresponding to global ID 2.
 * @param[in]  dst_offset_first_element Offset of the first element in the destination tensor.
 * @param[in]  tmp_ptr                  (IS_QUANTIZED only) Pointer to the temporary tensor.
 * @param[in]  tmp_stride_0             (IS_QUANTIZED only) Stride in bytes of the temporary tensor in the dimension corresponding to global ID 0.
 * @param[in]  tmp_stride_1             (IS_QUANTIZED only) Stride in bytes of the temporary tensor in the dimension corresponding to global ID 1.
 * @param[in]  tmp_stride_2             (IS_QUANTIZED only) Stride in bytes of the temporary tensor in the dimension corresponding to global ID 2.
 * @param[in]  tmp_offset_first_element (IS_QUANTIZED only) Offset of the first element in the temporary tensor.
 */
__kernel void softmax_x(
    __global uchar *src_ptr,
    uint src_stride_0,
    uint src_stride_1,
    uint src_stride_2,
    uint src_offset_first_element,
    __global uchar *dst_ptr,
    uint dst_stride_0,
    uint dst_stride_1,
    uint dst_stride_2,
    uint dst_offset_first_element
#ifdef IS_QUANTIZED
    ,
    __global uchar *tmp_ptr,
    uint tmp_stride_0,
    uint tmp_stride_1,
    uint tmp_stride_2,
    uint tmp_offset_first_element
#endif // IS_QUANTIZED
)
{
    const int dim_0 = get_global_id(0);
    const int dim_1 = get_global_id(1);
    const int dim_2 = get_global_id(2);

    // Move every pointer to the first element handled by this work-item.
    src_ptr += src_offset_first_element + dim_2 * src_stride_2 + dim_1 * src_stride_1 + dim_0 * src_stride_0;
    dst_ptr += dst_offset_first_element + dim_2 * dst_stride_2 + dim_1 * dst_stride_1 + dim_0 * dst_stride_0;
#ifdef IS_QUANTIZED
    tmp_ptr += tmp_offset_first_element + dim_2 * tmp_stride_2 + dim_1 * tmp_stride_1 + dim_0 * tmp_stride_0;
#else  // IS_QUANTIZED
    // Without quantization the intermediate values are kept in the destination buffer.
    __global uchar *tmp_ptr = dst_ptr;
#endif // IS_QUANTIZED

    // Calculate max value.
    // Note: the vector loops in this kernel use a strict '<' bound, so the scalar tail loops always
    // handle the last 1..VEC_SIZE elements (for LENGTH >= 1).
    DATA_TYPE max_value = MIN_VALUE;
    int       i         = 0;
    for (i = 0; i < LENGTH - VEC_SIZE; i += VEC_SIZE)
    {
        VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE)));
        max_value = max(max_value, MAX_REDUCE(data, VEC_SIZE));
    }
    for (; i < LENGTH; ++i)
    {
        DATA_TYPE data = *(__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE));
        max_value = max(max_value, data);
    }

    // Regularize the data.
    TMP_DATA_TYPE sum_value = 0;
#ifdef IS_QUANTIZED
    // The two SRC_OFFSET terms cancel, so REGULARIZE(x) evaluates to (x - max_value) * SRC_SCALE * BETA.
    TMP_DATA_TYPE max_value_f       = (CONVERT(max_value, TMP_DATA_TYPE) - SRC_OFFSET) * SRC_SCALE;
    TMP_DATA_TYPE regularize_offset = -SRC_OFFSET * SRC_SCALE * (TMP_DATA_TYPE)BETA - max_value_f * (TMP_DATA_TYPE)BETA;
#    define REGULARIZE(x) ((x) * SRC_SCALE * (TMP_DATA_TYPE)BETA + regularize_offset)
#else // IS_QUANTIZED
#    define REGULARIZE(x) (((x) - max_value) * (TMP_DATA_TYPE)BETA)
#endif // IS_QUANTIZED
    for (i = 0; i < LENGTH - VEC_SIZE; i += VEC_SIZE)
    {
        VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = CONVERT(VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE))), VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE));
        data = REGULARIZE(data);
#ifdef IS_LOG
        // Log softmax stores the regularized value itself; only the sum uses the exponential.
        sum_value += SUM_REDUCE(exp(data), VEC_SIZE);
#else  // IS_LOG
        data = exp(data);
        sum_value += SUM_REDUCE(data, VEC_SIZE);
#endif // IS_LOG
        VSTORE(VEC_SIZE)(data, 0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE)));
    }
    for (; i < LENGTH; ++i)
    {
        TMP_DATA_TYPE data = CONVERT(*(__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE)), TMP_DATA_TYPE);
        data = REGULARIZE(data);
#ifdef IS_LOG
        sum_value += exp(data);
#else  // IS_LOG
        data = exp(data);
        sum_value += data;
#endif // IS_LOG
        *(__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE)) = data;
    }
#undef REGULARIZE

    // Normalize the data.
    // IS_LOG is tested with #ifdef here, exactly as in the regularization pass, so both passes always agree.
#ifdef IS_QUANTIZED
#    ifdef IS_LOG
    // Quantized log softmax: q = (x - log(sum)) / DST_SCALE + DST_OFFSET.
    // The log(sum) term must be scaled by 1 / DST_SCALE just like x.
    TMP_DATA_TYPE norm_offset = -log(sum_value) / DST_SCALE + DST_OFFSET;
#        define NORMALIZE(SIZE, x) CONVERT_SAT_ROUND((x) / DST_SCALE + norm_offset, VEC_DATA_TYPE(DATA_TYPE, SIZE), rte)
#    else // IS_LOG
    // Quantized softmax: q = (x / sum) / DST_SCALE + DST_OFFSET.
    TMP_DATA_TYPE norm_div = sum_value * DST_SCALE;
#        define NORMALIZE(SIZE, x) CONVERT_SAT(add_sat(CONVERT_SAT_ROUND((x) / norm_div, VEC_DATA_TYPE(int, SIZE), rte), DST_OFFSET), VEC_DATA_TYPE(DATA_TYPE, SIZE))
#    endif // IS_LOG
#else // IS_QUANTIZED
#    ifdef IS_LOG
#        define NORMALIZE(SIZE, x) ((x) - log(sum_value))
#    else // IS_LOG
#        define NORMALIZE(SIZE, x) ((x) / sum_value)
#    endif // IS_LOG
#endif // IS_QUANTIZED
    for (i = 0; i < LENGTH - VEC_SIZE; i += VEC_SIZE)
    {
        VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE)));
        VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) result = NORMALIZE(VEC_SIZE, data);
        VSTORE(VEC_SIZE)(result, 0, (__global DATA_TYPE *)(dst_ptr + i * sizeof(DATA_TYPE)));
    }
    for (; i < LENGTH; ++i)
    {
        TMP_DATA_TYPE data   = *(__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE));
        DATA_TYPE     result = NORMALIZE(1, data);
        *(__global DATA_TYPE *)(dst_ptr + i * sizeof(DATA_TYPE)) = result;
    }
#undef NORMALIZE
}
#endif // SOFTMAX_X
#ifdef SOFTMAX_NON_X
/** 3-pass softmax in any dimension higher than the x dimension.
 *
 * Each work-item processes VEC_SIZE neighbouring elements in x and walks LENGTH steps along the softmax axis:
 *  1. Find the per-lane maximum and copy the input vectors into the temporary buffer.
 *  2. Compute (x - max) * BETA for every vector (exponentiated unless IS_LOG is defined), accumulate the
 *     per-lane sum of the exponentials and store the intermediate vectors in the temporary buffer.
 *  3. Normalize the intermediate vectors (and re-quantize them if IS_QUANTIZED is defined) into the destination.
 *
 * List of preprocessors:
 * - DATA_TYPE: the input/output data type.
 * - TMP_DATA_TYPE: the data type used for computing and temporary tensor storage.
 *   If DATA_TYPE is quantized, TMP_DATA_TYPE is floating-point, otherwise TMP_DATA_TYPE is the same as DATA_TYPE.
 * - IS_LOG (optional): if defined, log softmax is computed. Only the presence of the definition is checked.
 * - LENGTH: the number of elements in softmax axis in the input/output tensors.
 * - BETA: the beta coefficient.
 * - IS_QUANTIZED (optional): indicating whether the input/output data type is quantized data.
 * - VEC_SIZE: the size of the vector.
 * - VEC_SIZE_LEFTOVER: the size of the leftover part.
 *
 * Additional preprocessors in case IS_QUANTIZED is present:
 * - SRC_SCALE and SRC_OFFSET: the quantization information of the source tensor.
 * - DST_SCALE and DST_OFFSET: the quantization information of the destination tensor.
 *
 * @param[in]  src_ptr                  Pointer to the source tensor.
 * @param[in]  src_stride_0             Stride in bytes of the source tensor in the dimension corresponding to global ID 0.
 * @param[in]  src_stride_1             Stride in bytes of the source tensor in the dimension corresponding to global ID 1.
 * @param[in]  src_stride_2             Stride in bytes of the source tensor in the dimension corresponding to global ID 2.
 * @param[in]  src_offset_first_element Offset of the first element in the source tensor.
 * @param[out] dst_ptr                  Pointer to the destination tensor.
 * @param[in]  dst_stride_0             Stride in bytes of the destination tensor in the dimension corresponding to global ID 0.
 * @param[in]  dst_stride_1             Stride in bytes of the destination tensor in the dimension corresponding to global ID 1.
 * @param[in]  dst_stride_2             Stride in bytes of the destination tensor in the dimension corresponding to global ID 2.
 * @param[in]  dst_offset_first_element Offset of the first element in the destination tensor.
 * @param[in]  tmp_ptr                  Pointer to the temporary tensor.
 * @param[in]  tmp_stride_0             Stride in bytes of the temporary tensor in the dimension corresponding to global ID 0.
 * @param[in]  tmp_stride_1             Stride in bytes of the temporary tensor in the dimension corresponding to global ID 1.
 * @param[in]  tmp_stride_2             Stride in bytes of the temporary tensor in the dimension corresponding to global ID 2.
 * @param[in]  tmp_offset_first_element Offset of the first element in the temporary tensor.
 * @param[in]  src_stride_axis          Stride in bytes between consecutive elements of the source tensor along the softmax axis.
 * @param[in]  dst_stride_axis          Stride in bytes between consecutive elements of the destination tensor along the softmax axis.
 */
__kernel void softmax_non_x(
    __global uchar *src_ptr,
    uint src_stride_0,
    uint src_stride_1,
    uint src_stride_2,
    uint src_offset_first_element,
    __global uchar *dst_ptr,
    uint dst_stride_0,
    uint dst_stride_1,
    uint dst_stride_2,
    uint dst_offset_first_element,
    __global uchar *tmp_ptr,
    uint tmp_stride_0,
    uint tmp_stride_1,
    uint tmp_stride_2,
    uint tmp_offset_first_element,
    uint src_stride_axis,
    uint dst_stride_axis
)
{
    // First x element of this work-item. When VEC_SIZE_LEFTOVER != 0 every start is shifted back by
    // (VEC_SIZE - VEC_SIZE_LEFTOVER) and clamped at 0.
    const int dim_0 = max((int)get_global_id(0) * VEC_SIZE - (VEC_SIZE - VEC_SIZE_LEFTOVER) % VEC_SIZE, 0);
    const int dim_1 = get_global_id(1);
    const int dim_2 = get_global_id(2);

    src_ptr += src_offset_first_element + dim_2 * src_stride_2 + dim_1 * src_stride_1 + dim_0 * src_stride_0;
    dst_ptr += dst_offset_first_element + dim_2 * dst_stride_2 + dim_1 * dst_stride_1 + dim_0 * dst_stride_0;
    tmp_ptr += tmp_offset_first_element + dim_2 * tmp_stride_2 + dim_1 * tmp_stride_1 + dim_0 * tmp_stride_0;

    // In case of processing quantized data, i.e. DATA_TYPE is smaller than TMP_DATA_TYPE:
    //
    // In the first pass (finding max), the quantized data is copied from the input tensor to the temporary tensor.
    // Dequantization is not needed to find the max value and since dequantization widens the data, we defer it
    // to the second pass to reduce memory bandwidth of the first pass.
    //
    // In the second pass, it reads the quantized data from the temporary tensor and writes the dequantized data
    // back to the temporary tensor.
    //
    // To avoid dequantized data overwriting the unprocessed quantized data in the temporary tensor,
    // this extra offset is introduced to store the quantized data at the end of the temporary tensor.
    //
    // Note: Another approach is to perform the second pass in reverse order, but for unexplainable reasons
    //       it doesn't work on some devices.
    //
    // When TMP_DATA_TYPE and DATA_TYPE have the same size the offset evaluates to 0.
    uint tmp_extra_offset = LENGTH * VEC_SIZE * (sizeof(TMP_DATA_TYPE) - sizeof(DATA_TYPE));

    // Calculate max value and store the input data to the temporary tensor in suitable format.
    VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) max_value = MIN_VALUE;
    int i = 0;
    for (i = 0; i < LENGTH; ++i)
    {
        VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_ptr + i * src_stride_axis));
        max_value = max(max_value, data);
        VSTORE(VEC_SIZE)(data, 0, (__global DATA_TYPE *)(tmp_ptr + tmp_extra_offset + i * VEC_SIZE * sizeof(DATA_TYPE)));
    }

    // Regularize the data.
    VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) sum_value = 0;
#ifdef IS_QUANTIZED
    // The two SRC_OFFSET terms cancel, so REGULARIZE(x) evaluates to (x - max_value) * SRC_SCALE * BETA.
    VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) max_value_f = (CONVERT(max_value, VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE)) - SRC_OFFSET) * SRC_SCALE;
    VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) regularize_offset = -SRC_OFFSET * SRC_SCALE * (TMP_DATA_TYPE)BETA - max_value_f * (TMP_DATA_TYPE)BETA;
#    define REGULARIZE(x) ((x) * SRC_SCALE * (TMP_DATA_TYPE)BETA + regularize_offset)
#else // IS_QUANTIZED
#    define REGULARIZE(x) (((x) - max_value) * (TMP_DATA_TYPE)BETA)
#endif // IS_QUANTIZED
    for (i = 0; i < LENGTH; ++i)
    {
        VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = CONVERT(VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(tmp_ptr + tmp_extra_offset + i * VEC_SIZE * sizeof(DATA_TYPE))), VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE));
        data = REGULARIZE(data);
#ifdef IS_LOG
        // Log softmax stores the regularized value itself; only the sum uses the exponential.
        sum_value += exp(data);
#else  // IS_LOG
        data = exp(data);
        sum_value += data;
#endif // IS_LOG
        VSTORE(VEC_SIZE)(data, 0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * VEC_SIZE * sizeof(TMP_DATA_TYPE)));
    }
#undef REGULARIZE

    // Normalize the data.
    // IS_LOG is tested with #ifdef here, exactly as in the regularization pass, so both passes always agree.
#ifdef IS_QUANTIZED
#    ifdef IS_LOG
    // Quantized log softmax: q = (x - log(sum)) / DST_SCALE + DST_OFFSET.
    // The log(sum) term must be scaled by 1 / DST_SCALE just like x.
    VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) norm_offset = -log(sum_value) / DST_SCALE + DST_OFFSET;
#        define NORMALIZE(x) CONVERT_SAT_ROUND((x) / DST_SCALE + norm_offset, VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE), rte)
#    else // IS_LOG
    // Quantized softmax: q = (x / sum) / DST_SCALE + DST_OFFSET.
    VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) norm_div = sum_value * DST_SCALE;
#        define NORMALIZE(x) CONVERT_SAT(add_sat(CONVERT_SAT_ROUND((x) / norm_div, VEC_DATA_TYPE(int, VEC_SIZE), rte), DST_OFFSET), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))
#    endif // IS_LOG
#else // IS_QUANTIZED
#    ifdef IS_LOG
#        define NORMALIZE(x) ((x) - log(sum_value))
#    else // IS_LOG
#        define NORMALIZE(x) ((x) / sum_value)
#    endif // IS_LOG
#endif // IS_QUANTIZED
    for (i = 0; i < LENGTH; ++i)
    {
        VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * VEC_SIZE * sizeof(TMP_DATA_TYPE)));
        // NOTE(review): STORE_VECTOR_SELECT presumably appends the row suffix "0" to the basename `result`,
        // which is why the variable is named `result0` - confirm against helpers.h.
        VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) result0 = NORMALIZE(data);
        STORE_VECTOR_SELECT(result, DATA_TYPE, dst_ptr + i * dst_stride_axis, VEC_SIZE, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0)
    }
#undef NORMALIZE
}
#endif // SOFTMAX_NON_X
/* Remove the file-local MIN_VALUE helpers: per-type constants first, then the dispatch macros. */
#undef MIN_VALUE_uchar
#undef MIN_VALUE_char
#undef MIN_VALUE_half
#undef MIN_VALUE_float
#undef MIN_VALUE_TYPE_STR
#undef MIN_VALUE_TYPE
#undef MIN_VALUE