Optimize CL softmax

* The new softmax implementation performs the whole operation in a single kernel
  (a scalar reference sketch of the algorithm is included below).
  - There are two versions of the kernel: one for the x dimension
    and one for any other dimension.
  - The kernel handles both native and quantized data types.
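
For reference, a minimal scalar sketch (plain C, illustrative names only) of the
3-pass algorithm that both kernel variants implement, shown for the non-quantized,
non-log case:

    #include <float.h>
    #include <math.h>

    static void softmax_reference(const float *src, float *dst, int length, float beta)
    {
        // Pass 1: find the maximum value of the row.
        float max_value = -FLT_MAX;
        for (int i = 0; i < length; ++i)
            max_value = fmaxf(max_value, src[i]);

        // Pass 2: subtract the max, apply beta, exponentiate and accumulate the sum.
        float sum_value = 0.0f;
        for (int i = 0; i < length; ++i)
        {
            dst[i] = expf((src[i] - max_value) * beta);
            sum_value += dst[i];
        }

        // Pass 3: normalize by the sum.
        for (int i = 0; i < length; ++i)
            dst[i] /= sum_value;
    }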

Resolves: COMPMID-6447
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: I4a9ae5bc63f78aebeaa85ee48a0d102c9c245eda
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10489
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/cl_kernels/common/softmax_layer.cl b/src/core/CL/cl_kernels/common/softmax_layer.cl
index 4d2d89d..58c4589 100644
--- a/src/core/CL/cl_kernels/common/softmax_layer.cl
+++ b/src/core/CL/cl_kernels/common/softmax_layer.cl
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -21,511 +21,344 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
+
 #include "helpers.h"
 
-#if defined(DATA_TYPE) && defined(MIN_VALUE) && defined(VECTOR_SIZE) && defined(VECTOR_SIZE_LEFTOVER)
+#define MIN_VALUE_float -FLT_MAX
+#define MIN_VALUE_half  -HALF_MAX
+#define MIN_VALUE_char  CHAR_MIN
+#define MIN_VALUE_uchar 0
 
-/** Divides all the values of the input tensor by the sum calculated from softmax_layer_shift_exp_sum kernel.
+#define MIN_VALUE_TYPE_STR(data_type) MIN_VALUE_##data_type
+#define MIN_VALUE_TYPE(data_type) MIN_VALUE_TYPE_STR(data_type)
+#define MIN_VALUE MIN_VALUE_TYPE(DATA_TYPE)
+
+#ifdef SOFTMAX_X
+
+/** 3-pass softmax in the x dimension.
  *
- * @note Datatype must be given as a preprocessor argument using -DDATA_TYPE, e.g. -DDATA_TYPE=float
- * @note The zero value for the given data type must be given as a preprocessor argument using -DMIN_VALUE, e.g. -DMIN_VALUE=0
- * @note Vector size should be given as a preprocessor argument using -DVECTOR_SIZE=size. e.g. -DVECTOR_SIZE=16
- * @note Leftover vector size has to be passed at compile time using -DVECTOR_SIZE_LEFTOVER. e.g. -DVECTOR_SIZE_LEFTOVER=3. It is defined as the remainder between the input's first dimension and VECTOR_SIZE
- * @note In case of log softmax, -DLOG_SOFTMAX must be passed.
+ * List of preprocessors:
+ *   - DATA_TYPE: the input/output data type.
+ *   - TMP_DATA_TYPE: the data type used for computation and temporary tensor storage.
+ *     If DATA_TYPE is quantized, TMP_DATA_TYPE is a floating-point type, otherwise it is the same as DATA_TYPE.
+ *   - IS_LOG (optional): if present, log softmax is computed.
+ *   - LENGTH: the number of elements in the softmax axis of the input/output tensors.
+ *   - BETA: the beta coefficient.
+ *   - IS_QUANTIZED (optional): if present, the input/output data type is quantized.
+ *   - VEC_SIZE: the number of elements processed per vector operation.
  *
- * @param[in]  src_ptr                           Pointer to the source tensor slice. Supported data types: F16/F32
- * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
- * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
- * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
- * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
- * @param[in]  sum_ptr                           Pointer to the sum values tensor slice. Supported data types: same as @p src_ptr
- * @param[in]  sum_stride_x                      Stride of the sum values tensor in X dimension (in bytes)
- * @param[in]  sum_step_x                        sum_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  sum_stride_y                      Stride of the sum values tensor in Y dimension (in bytes)
- * @param[in]  sum_step_y                        sum_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  sum_stride_z                      Stride of the sum values tensor in Z dimension (in bytes)
- * @param[in]  sum_step_z                        sum_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  sum_offset_first_element_in_bytes The offset of the first element in the sum values tensor
- * @param[out] dst_ptr                           Pointer to the destination tensor slice. Supported data types: same as @p src_ptr
- * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
- * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
- * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ * Additional preprocessors in case IS_QUANTIZED is present:
+ *   - SRC_SCALE and SRC_OFFSET: the quantization information of the source tensor.
+ *   - DST_SCALE and DST_OFFSET: the quantization information of the destination tensor.
+ *
+ * @param[in] src_ptr                  Pointer to the source tensor.
+ * @param[in] src_stride_0             Stride in bytes of the source tensor in the dimension corresponding to global ID 0.
+ * @param[in] src_stride_1             Stride in bytes of the source tensor in the dimension corresponding to global ID 1.
+ * @param[in] src_stride_2             Stride in bytes of the source tensor in the dimension corresponding to global ID 2.
+ * @param[in] src_offset_first_element Offset of the first element in the source tensor.
+ * @param[in] dst_ptr                  Pointer to the destination tensor.
+ * @param[in] dst_stride_0             Stride in bytes of the destination tensor in the dimension corresponding to global ID 0.
+ * @param[in] dst_stride_1             Stride in bytes of the destination tensor in the dimension corresponding to global ID 1.
+ * @param[in] dst_stride_2             Stride in bytes of the destination tensor in the dimension corresponding to global ID 2.
+ * @param[in] dst_offset_first_element Offset of the first element in the destination tensor.
+ * @param[in] tmp_ptr                  Pointer to the temporary tensor.
+ * @param[in] tmp_stride_0             Stride in bytes of the temporary tensor in the dimension corresponding to global ID 0.
+ * @param[in] tmp_stride_1             Stride in bytes of the temporary tensor in the dimension corresponding to global ID 1.
+ * @param[in] tmp_stride_2             Stride in bytes of the temporary tensor in the dimension corresponding to global ID 2.
+ * @param[in] tmp_offset_first_element Offset of the first element in the temporary tensor.
  */
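+
+// Illustrative example only (the actual options are set by the host at kernel build time):
+// a non-quantized F32 softmax of length 128 could be built with
+//   -DSOFTMAX_X -DDATA_TYPE=float -DTMP_DATA_TYPE=float -DLENGTH=128 -DBETA=1.0f -DVEC_SIZE=16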
-__kernel void softmax_layer_norm(
-    TENSOR3D_DECLARATION(src),
-    TENSOR3D_DECLARATION(sum),
-    TENSOR3D_DECLARATION(dst))
+__kernel void softmax_x(
+    __global uchar *src_ptr,
+    uint src_stride_0,
+    uint src_stride_1,
+    uint src_stride_2,
+    uint src_offset_first_element,
+
+    __global uchar *dst_ptr,
+    uint dst_stride_0,
+    uint dst_stride_1,
+    uint dst_stride_2,
+    uint dst_offset_first_element
+
+#ifdef IS_QUANTIZED
+    ,
+    __global uchar *tmp_ptr,
+    uint tmp_stride_0,
+    uint tmp_stride_1,
+    uint tmp_stride_2,
+    uint tmp_offset_first_element
+#endif // IS_QUANTIZED
+)
 {
-    const int x_offs = max((int)(get_global_id(0) * VECTOR_SIZE - (VECTOR_SIZE - VECTOR_SIZE_LEFTOVER) % VECTOR_SIZE), 0) * sizeof(DATA_TYPE);
+    const int dim_0 = get_global_id(0);
+    const int dim_1 = get_global_id(1);
+    const int dim_2 = get_global_id(2);
 
-    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x_offs + get_global_id(1) * src_stride_y + get_global_id(2) * src_stride_z;
-    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_offs + get_global_id(1) * dst_stride_y + get_global_id(2) * dst_stride_z;
+    src_ptr += src_offset_first_element + dim_2 * src_stride_2 + dim_1 * src_stride_1 + dim_0 * src_stride_0;
+    dst_ptr += dst_offset_first_element + dim_2 * dst_stride_2 + dim_1 * dst_stride_1 + dim_0 * dst_stride_0;
 
-    Image sum = CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP(sum);
+#ifdef IS_QUANTIZED
+    tmp_ptr += tmp_offset_first_element + dim_2 * tmp_stride_2 + dim_1 * tmp_stride_1 + dim_0 * tmp_stride_0;
+#else // IS_QUANTIZED
+    __global uchar *tmp_ptr = dst_ptr;
+#endif // IS_QUANTIZED
 
-    // Load max value of 1D logits vector (row)
-    DATA_TYPE sum_val = *((__global DATA_TYPE *)offset(&sum, 0, get_global_id(1)));
-    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
-    data0 = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)src_addr);
+    // Calculate max value.
+    DATA_TYPE max_value = MIN_VALUE;
+    int i = 0;
 
-#if defined(LOG_SOFTMAX)
-    sum_val = log(sum_val);
-    data0 -= sum_val;
-#else  // defined(LOG_SOFTMAX)
-    data0 /= sum_val;
-#endif // defined(LOG_SOFTMAX)
+    for (i = 0; i < LENGTH - VEC_SIZE; i += VEC_SIZE)
+    {
+        VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE)));
 
-    STORE_VECTOR_SELECT(data, DATA_TYPE, dst_addr, VECTOR_SIZE, VECTOR_SIZE_LEFTOVER, VECTOR_SIZE_LEFTOVER != 0 && get_global_id(0) == 0)
+        max_value = max(max_value, MAX_REDUCE(data, VEC_SIZE));
+    }
+
+    for (; i < LENGTH; ++i)
+    {
+        DATA_TYPE data = *(__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE));
+
+        max_value = max(max_value, data);
+    }
+
+    // Regularize the data.
+    TMP_DATA_TYPE sum_value = 0;
+
+#ifdef IS_QUANTIZED
+    TMP_DATA_TYPE max_value_f = (CONVERT(max_value, TMP_DATA_TYPE) - SRC_OFFSET) * SRC_SCALE;
+    TMP_DATA_TYPE regularize_offset = -SRC_OFFSET * SRC_SCALE * (TMP_DATA_TYPE)BETA - max_value_f * (TMP_DATA_TYPE)BETA;
+# define REGULARIZE(x) ((x) * SRC_SCALE * (TMP_DATA_TYPE)BETA + regularize_offset)
+#else // IS_QUANTIZED
+# define REGULARIZE(x) (((x) - max_value) * (TMP_DATA_TYPE)BETA)
+#endif // IS_QUANTIZED
+
+    for (i = 0; i < LENGTH - VEC_SIZE; i += VEC_SIZE)
+    {
+        VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = CONVERT(VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE))), VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE));
+
+        data = REGULARIZE(data);
+
+#ifdef IS_LOG
+        sum_value += SUM_REDUCE(exp(data), VEC_SIZE);
+#else // IS_LOG
+        data = exp(data);
+        sum_value += SUM_REDUCE(data, VEC_SIZE);
+#endif // IS_LOG
+
+        VSTORE(VEC_SIZE)(data, 0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE)));
+    }
+
+    for (; i < LENGTH; ++i)
+    {
+        TMP_DATA_TYPE data = CONVERT(*(__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE)), TMP_DATA_TYPE);
+
+        data = REGULARIZE(data);
+
+#ifdef IS_LOG
+        sum_value += exp(data);
+#else // IS_LOG
+        data = exp(data);
+        sum_value += data;
+#endif // IS_LOG
+
+        *(__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE)) = data;
+    }
+
+#undef REGULARIZE
+
+    // Normalize the data.
+#ifdef IS_QUANTIZED
+# if IS_LOG
+    TMP_DATA_TYPE norm_offset = -log(sum_value) + DST_OFFSET;
+#  define NORMALIZE(SIZE, x) CONVERT_SAT_ROUND((x) / DST_SCALE + norm_offset, VEC_DATA_TYPE(DATA_TYPE, SIZE), rte)
+# else // IS_LOG
+    TMP_DATA_TYPE norm_div = sum_value * DST_SCALE;
+#  define NORMALIZE(SIZE, x) CONVERT_SAT(add_sat(CONVERT_SAT_ROUND((x) / norm_div, VEC_DATA_TYPE(int, SIZE), rte), DST_OFFSET), VEC_DATA_TYPE(DATA_TYPE, SIZE))
+# endif // IS_LOG
+#else // IS_QUANTIZED
+# if IS_LOG
+#  define NORMALIZE(SIZE, x) ((x) - log(sum_value))
+# else // IS_LOG
+#  define NORMALIZE(SIZE, x) ((x) / sum_value)
+# endif // IS_LOG
+#endif // IS_QUANTIZED
+
+    for (i = 0; i < LENGTH - VEC_SIZE; i += VEC_SIZE)
+    {
+        VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE)));
+
+        VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) result = NORMALIZE(VEC_SIZE, data);
+
+        VSTORE(VEC_SIZE)(result, 0, (__global DATA_TYPE *)(dst_ptr + i * sizeof(DATA_TYPE)));
+    }
+
+    for (; i < LENGTH; ++i)
+    {
+        TMP_DATA_TYPE data = *(__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE));
+
+        DATA_TYPE result = NORMALIZE(1, data);
+
+        *(__global DATA_TYPE *)(dst_ptr + i * sizeof(DATA_TYPE)) = result;
+    }
+
+#undef NORMALIZE
 }
 
-#if defined(SRC_WIDTH) && defined(LOG_VECTOR_SIZE) && defined(MINVAL)
+#endif // SOFTMAX_X
 
-/* Number of workitems in dimension 0. */
-#if !defined(GRID_SIZE)
-#define GRID_SIZE 1
-#endif /* !defined(GRID_SIZE) */
+#ifdef SOFTMAX_NON_X
 
-#define VEC_TYPE VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
-#define SELECT_TYPE SELECT_VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
-
-/** Identifies the maximum value across the 1st dimension and shifts the values of the input tensor by this maximum value,
- * then gets the exponent of each element as sums all elements across each row.
+/** 3-pass softmax in any dimension higher than the x dimension.
  *
- * @note Datatype must be given as a preprocessor argument using -DDATA_TYPE, e.g. -DDATA_TYPE=float
- * @note The zero value for the given data type must be given as a preprocessor argument using -DMIN_VALUE, e.g. -DMIN_VALUE=0
- * @note Vector size should be given as a preprocessor argument using -DVECTOR_SIZE=size. e.g. -DVECTOR_SIZE=16
- * @note Leftover vector size has to be passed at compile time using -DVECTOR_SIZE_LEFTOVER. e.g. -DVECTOR_SIZE_LEFTOVER=3. It is defined as the remainder between the input's first dimension and VECTOR_SIZE
- * @note In case the input is not a multiple of VECTOR_SIZE (2,4,8,16) -DNON_MULTIPLE_OF_VECTOR_SIZE must be passed.
- * @note Beta can be optionally passed at compile time using -DBETA (by default, it is 1.0).
- * @note In case of log softmax, -DLOG_SOFTMAX must be passed.
- * @note Based on the data type, the minimum possible value must be passed using -DMINVAL. For float it should be defined as -FLT_MAX, while for half it should be -HALF_MAX
+ * List of preprocessors:
+ *   - DATA_TYPE: the input/output data type.
+ *   - TMP_DATA_TYPE: the data type used for computation and temporary tensor storage.
+ *     If DATA_TYPE is quantized, TMP_DATA_TYPE is a floating-point type, otherwise it is the same as DATA_TYPE.
+ *   - IS_LOG (optional): if present, log softmax is computed.
+ *   - LENGTH: the number of elements in the softmax axis of the input/output tensors.
+ *   - BETA: the beta coefficient.
+ *   - IS_QUANTIZED (optional): if present, the input/output data type is quantized.
+ *   - VEC_SIZE: the number of elements processed per vector operation.
+ *   - VEC_SIZE_LEFTOVER: the leftover part of the tensors' first dimension when it is not a multiple of VEC_SIZE.
  *
- * @param[in]  src_ptr                            Pointer to the source tensor slice. Supported data types: F16/F32
- * @param[in]  src_stride_x                       Stride of the source tensor in X dimension (in bytes)
- * @param[in]  src_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  src_stride_y                       Stride of the source tensor in Y dimension (in bytes)
- * @param[in]  src_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  src_stride_z                       Stride of the source tensor in Z dimension (in bytes)
- * @param[in]  src_step_z                         src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  src_offset_first_element_in_bytes  The offset of the first element in the source tensor
- * @param[in]  maxo_ptr                           Pointer to the max values tensor slice. Supported data types: same as @p src_ptr
- * @param[in]  maxo_stride_x                      Stride of the max values tensor in X dimension (in bytes)
- * @param[in]  maxo_step_x                        max_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  maxo_stride_y                      Stride of the max values tensor in Y dimension (in bytes)
- * @param[in]  maxo_step_y                        max_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  maxo_stride_z                      Stride of the max values tensor in Z dimension (in bytes)
- * @param[in]  maxo_step_z                        max_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  maxo_offset_first_element_in_bytes The offset of the first element in the max values tensor
- * @param[out] dst_ptr                            Pointer to the destination tensor slice. Supported data types: same as @p src_ptr
- * @param[in]  dst_stride_x                       Stride of the destination tensor in X dimension (in bytes)
- * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  dst_stride_y                       Stride of the destination tensor in Y dimension (in bytes)
- * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  dst_step_z                         dst_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination tensor
- * @param[out] sum_ptr                            Pointer to the sum values tensor slice. Supported data types: same as @p src_ptr
- * @param[in]  sum_stride_x                       Stride of the sum values tensor in X dimension (in bytes)
- * @param[in]  sum_step_x                         sum_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  sum_stride_y                       Stride of the sum values tensor in Y dimension (in bytes)
- * @param[in]  sum_step_y                         sum_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  sum_stride_z                       Stride of the sum values tensor in Z dimension (in bytes)
- * @param[in]  sum_step_z                         sum_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  sum_offset_first_element_in_bytes  The offset of the first element in the sum values tensor
+ * Additional preprocessors in case IS_QUANTIZED is present:
+ *   - SRC_SCALE and SRC_OFFSET: the quantization information of the source tensor.
+ *   - DST_SCALE and DST_OFFSET: the quantization information of the destination tensor.
+ *
+ * @param[in] src_ptr                  Pointer to the source tensor.
+ * @param[in] src_stride_0             Stride in bytes of the source tensor in the dimension corresponding to global ID 0.
+ * @param[in] src_stride_1             Stride in bytes of the source tensor in the dimension corresponding to global ID 1.
+ * @param[in] src_stride_2             Stride in bytes of the source tensor in the dimension corresponding to global ID 2.
+ * @param[in] src_offset_first_element Offset of the first element in the source tensor.
+ * @param[in] dst_ptr                  Pointer to the destination tensor.
+ * @param[in] dst_stride_0             Stride in bytes of the destination tensor in the dimension corresponding to global ID 0.
+ * @param[in] dst_stride_1             Stride in bytes of the destination tensor in the dimension corresponding to global ID 1.
+ * @param[in] dst_stride_2             Stride in bytes of the destination tensor in the dimension corresponding to global ID 2.
+ * @param[in] dst_offset_first_element Offset of the first element in the destination tensor.
+ * @param[in] tmp_ptr                  Pointer to the temporary tensor.
+ * @param[in] tmp_stride_0             Stride in bytes of the temporary tensor in the dimension corresponding to global ID 0.
+ * @param[in] tmp_stride_1             Stride in bytes of the temporary tensor in the dimension corresponding to global ID 1.
+ * @param[in] tmp_stride_2             Stride in bytes of the temporary tensor in the dimension corresponding to global ID 2.
+ * @param[in] tmp_offset_first_element Offset of the first element in the temporary tensor.
  */
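+
+// Illustrative example only (the actual options are set by the host at kernel build time):
+// the same non-quantized F32 softmax, vectorized across the first dimension, could be built with
+//   -DSOFTMAX_NON_X -DDATA_TYPE=float -DTMP_DATA_TYPE=float -DLENGTH=128 -DBETA=1.0f -DVEC_SIZE=16 -DVEC_SIZE_LEFTOVER=0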
-__kernel void softmax_layer_max_shift_exp_sum_serial(
-    TENSOR3D_DECLARATION(src),
-    TENSOR3D_DECLARATION(maxo),
-    TENSOR3D_DECLARATION(dst),
-    TENSOR3D_DECLARATION(sum))
+__kernel void softmax_non_x(
+    __global uchar *src_ptr,
+    uint src_stride_0,
+    uint src_stride_1,
+    uint src_stride_2,
+    uint src_offset_first_element,
+
+    __global uchar *dst_ptr,
+    uint dst_stride_0,
+    uint dst_stride_1,
+    uint dst_stride_2,
+    uint dst_offset_first_element,
+
+    __global uchar *tmp_ptr,
+    uint tmp_stride_0,
+    uint tmp_stride_1,
+    uint tmp_stride_2,
+    uint tmp_offset_first_element,
+
+    uint src_stride_axis,
+    uint dst_stride_axis
+)
 {
-    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(1) * src_stride_y + get_global_id(2) * src_stride_z;
-    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + get_global_id(1) * dst_stride_y + get_global_id(2) * dst_stride_z;
+    const int dim_0 = max((int)get_global_id(0) * VEC_SIZE - (VEC_SIZE - VEC_SIZE_LEFTOVER) % VEC_SIZE, 0);
+    const int dim_1 = get_global_id(1);
+    const int dim_2 = get_global_id(2);
 
-    Image maxo = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(maxo);
-    Image sum  = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(sum);
+    src_ptr += src_offset_first_element + dim_2 * src_stride_2 + dim_1 * src_stride_1 + dim_0 * src_stride_0;
+    dst_ptr += dst_offset_first_element + dim_2 * dst_stride_2 + dim_1 * dst_stride_1 + dim_0 * dst_stride_0;
+    tmp_ptr += tmp_offset_first_element + dim_2 * tmp_stride_2 + dim_1 * tmp_stride_1 + dim_0 * tmp_stride_0;
 
-#ifdef BETA
-    // Initialize beta
-    VEC_TYPE beta = (VEC_TYPE)BETA;
-#endif /* BETA */
+    // Calculate the max value and store the input data in the temporary tensor in a suitable format.
+    VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) max_value = MIN_VALUE;
+    int i = 0;
 
-    // Initialize local maximum
-    VEC_TYPE max_val_vec = (VEC_TYPE)(MINVAL);
-
-#ifdef NON_MULTIPLE_OF_VECTOR_SIZE
-    VEC_TYPE data    = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)src_addr);
-    SELECT_TYPE widx = (SELECT_TYPE)VECTOR_SIZE_LEFTOVER > VEC_OFFS(SELECT_DATA_TYPE(DATA_TYPE), VECTOR_SIZE);
-    max_val_vec      = max(max_val_vec, select((VEC_TYPE)(MINVAL), data, widx));
-#endif /* NON_MULTIPLE_OF_VECTOR_SIZE */
-
-    for(uint i = VECTOR_SIZE_LEFTOVER; i < SRC_WIDTH; i += VECTOR_SIZE)
+    for (i = 0; i < LENGTH; ++i)
     {
-        VEC_TYPE data = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)(src_addr + i * sizeof(DATA_TYPE)));
-        max_val_vec   = max(data, max_val_vec);
+        VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_ptr + i * src_stride_axis));
+
+        max_value = max(max_value, data);
+
+        VSTORE(VEC_SIZE)(data, 0, (__global DATA_TYPE *)(tmp_ptr + i * VEC_SIZE * sizeof(DATA_TYPE)));
     }
 
-    // Perform max reduction
-    DATA_TYPE max_val                 = MAX_REDUCE(max_val_vec, VECTOR_SIZE);
-    *((__global DATA_TYPE *)maxo.ptr) = max_val;
+    // Regularize the data.
+    VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) sum_value = 0;
 
-    /* Second section */
+#ifdef IS_QUANTIZED
+    VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) max_value_f = (CONVERT(max_value, VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE)) - SRC_OFFSET) * SRC_SCALE;
+    VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) regularize_offset = -SRC_OFFSET * SRC_SCALE * (TMP_DATA_TYPE)BETA - max_value_f * (TMP_DATA_TYPE)BETA;
+# define REGULARIZE(x) ((x) * SRC_SCALE * (TMP_DATA_TYPE)BETA + regularize_offset)
+#else // IS_QUANTIZED
+# define REGULARIZE(x) (((x) - max_value) * (TMP_DATA_TYPE)BETA)
+#endif // IS_QUANTIZED
 
-    // Set sum vector
-    VEC_TYPE sum1D = 0;
-
-#ifdef NON_MULTIPLE_OF_VECTOR_SIZE
-    data -= max_val;
-#ifdef BETA
-    data *= beta;
-#endif /* BETA */
-#ifdef LOG_SOFTMAX
-    VSTORE_PARTIAL(VECTOR_SIZE, VECTOR_SIZE_LEFTOVER)
-    (data, 0, (__global DATA_TYPE *)dst_addr);
-    data = exp(data);
-    data = select(0, data, widx);
-#else  /* LOG_SOFTMAX */
-    data = exp(data);
-    data = select(0, data, widx);
-    VSTORE_PARTIAL(VECTOR_SIZE, VECTOR_SIZE_LEFTOVER)
-    (data, 0, (__global DATA_TYPE *)dst_addr);
-#endif /* LOG_SOFTMAX */
-    sum1D += data;
-#endif /* NON_MULTIPLE_OF_VECTOR_SIZE */
-
-    // Shift values, exp and sum
-    for(uint i = VECTOR_SIZE_LEFTOVER; i < SRC_WIDTH; i += VECTOR_SIZE)
+    for (i = LENGTH - 1; i >= 0; --i)
     {
-        VEC_TYPE data = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)(src_addr + i * sizeof(DATA_TYPE)));
-        data -= max_val;
-#ifdef BETA
-        data *= beta;
-#endif /* BETA */
-#ifdef LOG_SOFTMAX
-        VSTORE(VECTOR_SIZE)
-        (data, 0, (__global DATA_TYPE *)(dst_addr + i * sizeof(DATA_TYPE)));
+        // In case of processing quantized data, i.e. DATA_TYPE is smaller than TMP_DATA_TYPE:
+        //
+        // In the first pass (finding max), the quantized data is copied from the input tensor to the temporary tensor.
+        // Dequantization is not needed to find the max value, and since dequantization widens the data,
+        // we defer it to the second pass to reduce the memory bandwidth of the first pass.
+        //
+        // This pass reads the quantized data from the temporary tensor and writes the dequantized data
+        // back to the temporary tensor, hence we need to loop in reverse to avoid overwriting unprocessed data.
+
+        VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = CONVERT(VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(tmp_ptr + i * VEC_SIZE * sizeof(DATA_TYPE))), VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE));
+
+        data = REGULARIZE(data);
+
+#ifdef IS_LOG
+        sum_value += exp(data);
+#else // IS_LOG
         data = exp(data);
-#else  /* LOG_SOFTMAX */
-        data = exp(data);
-        VSTORE(VECTOR_SIZE)
-        (data, 0, (__global DATA_TYPE *)(dst_addr + i * sizeof(DATA_TYPE)));
-#endif /* LOG_SOFTMAX */
-        sum1D += data;
+        sum_value += data;
+#endif // IS_LOG
+
+        VSTORE(VEC_SIZE)(data, 0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * VEC_SIZE * sizeof(TMP_DATA_TYPE)));
     }
 
-    // Perform sum reduction
-    *((__global DATA_TYPE *)sum.ptr) = SUM_REDUCE(sum1D, VECTOR_SIZE);
+#undef REGULARIZE
+
+    // Normalize the data.
+#ifdef IS_QUANTIZED
+# if IS_LOG
+    VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) norm_offset = -log(sum_value) + DST_OFFSET;
+#  define NORMALIZE(x) CONVERT_SAT_ROUND((x) / DST_SCALE + norm_offset, VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE), rte)
+# else // IS_LOG
+    VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) norm_div = sum_value * DST_SCALE;
+#  define NORMALIZE(x) CONVERT_SAT(add_sat(CONVERT_SAT_ROUND((x) / norm_div, VEC_DATA_TYPE(int, VEC_SIZE), rte), DST_OFFSET), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))
+# endif // IS_LOG
+#else // IS_QUANTIZED
+# if IS_LOG
+#  define NORMALIZE(x) ((x) - log(sum_value))
+# else // IS_LOG
+#  define NORMALIZE(x) ((x) / sum_value)
+# endif // IS_LOG
+#endif // IS_QUANTIZED
+
+    for (i = 0; i < LENGTH; ++i)
+    {
+        VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * VEC_SIZE * sizeof(TMP_DATA_TYPE)));
+
+        VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) result0 = NORMALIZE(data);
+
+        STORE_VECTOR_SELECT(result, DATA_TYPE, dst_ptr + i * dst_stride_axis, VEC_SIZE, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0)
+    }
+
+#undef NORMALIZE
 }
 
-/** Identifies the maximum value across the 1st dimension and shifts the values of the input tensor by this maximum value,
- * then gets the exponent of each element as sums all elements across each row.
- *
- * @note Datatype must be given as a preprocessor argument using -DDATA_TYPE, e.g. -DDATA_TYPE=float
- * @note The zero value for the given data type must be given as a preprocessor argument using -DMIN_VALUE, e.g. -DMIN_VALUE=0
- * @note Vector size should be given as a preprocessor argument using -DVECTOR_SIZE=size. e.g. -DVECTOR_SIZE=16
- * @note Leftover vector size has to be passed at compile time using -DVECTOR_SIZE_LEFTOVER. e.g. -DVECTOR_SIZE_LEFTOVER=3. It is defined as the remainder between the input's first dimension and VECTOR_SIZE
- * @note In case the input is not a multiple of VECTOR_SIZE (2,4,8,16) -DNON_MULTIPLE_OF_VECTOR_SIZE must be passed.
- * @note Beta can be optionally passed at compile time using -DBETA (by default, it is 1.0).
- * @note In case of log softmax, -DLOG_SOFTMAX must be passed.
- * @note Based on the data type, the minimum possible value must be passed using -DMINVAL. For float it should be defined as -FLT_MAX, while for half it should be -HALF_MAX
- *
- * @param[in]  src_ptr                            Pointer to the source tensor slice. Supported data types: F16/F32
- * @param[in]  src_stride_x                       Stride of the source tensor in X dimension (in bytes)
- * @param[in]  src_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  src_stride_y                       Stride of the source tensor in Y dimension (in bytes)
- * @param[in]  src_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  src_stride_z                       Stride of the source tensor in Z dimension (in bytes)
- * @param[in]  src_step_z                         src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  src_offset_first_element_in_bytes  The offset of the first element in the source tensor
- * @param[in]  maxo_ptr                           Pointer to the max values tensor slice. Supported data types: same as @p src_ptr
- * @param[in]  maxo_stride_x                      Stride of the max values tensor in X dimension (in bytes)
- * @param[in]  maxo_step_x                        max_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  maxo_stride_y                      Stride of the max values tensor in Y dimension (in bytes)
- * @param[in]  maxo_step_y                        max_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  maxo_stride_z                      Stride of the max values tensor in Z dimension (in bytes)
- * @param[in]  maxo_step_z                        max_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  maxo_offset_first_element_in_bytes The offset of the first element in the max values tensor
- * @param[out] dst_ptr                            Pointer to the destination tensor slice. Supported data types: same as @p src_ptr
- * @param[in]  dst_stride_x                       Stride of the destination tensor in X dimension (in bytes)
- * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  dst_stride_y                       Stride of the destination tensor in Y dimension (in bytes)
- * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  dst_step_z                         dst_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination tensor
- * @param[out] sum_ptr                            Pointer to the sum values tensor slice. Supported data types: same as @p src_ptr
- * @param[in]  sum_stride_x                       Stride of the sum values tensor in X dimension (in bytes)
- * @param[in]  sum_step_x                         sum_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  sum_stride_y                       Stride of the sum values tensor in Y dimension (in bytes)
- * @param[in]  sum_step_y                         sum_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  sum_stride_z                       Stride of the sum values tensor in Z dimension (in bytes)
- * @param[in]  sum_step_z                         sum_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  sum_offset_first_element_in_bytes  The offset of the first element in the sum values tensor
- */
-__kernel void softmax_layer_max_shift_exp_sum_parallel(
-    TENSOR3D_DECLARATION(src),
-    TENSOR3D_DECLARATION(maxo),
-    TENSOR3D_DECLARATION(dst),
-    TENSOR3D_DECLARATION(sum))
-{
-    const uint lid    = get_local_id(0);
-    const uint x_offs = (VECTOR_SIZE_LEFTOVER + lid * VECTOR_SIZE) * sizeof(DATA_TYPE);
+#endif // SOFTMAX_NON_X
 
-    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x_offs + get_global_id(1) * src_stride_y + get_global_id(2) * src_stride_z;
-    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_offs + get_global_id(1) * dst_stride_y + get_global_id(2) * dst_stride_z;
+#undef MIN_VALUE
+#undef MIN_VALUE_TYPE
+#undef MIN_VALUE_TYPE_STR
 
-    Image maxo = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(maxo);
-    Image sum  = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(sum);
-
-#ifdef BETA
-    // Initialize beta
-    VEC_TYPE beta = (VEC_TYPE)BETA;
-#endif /* BETA */
-
-    // Define one temporary vector per work-item.
-    __local VEC_TYPE tmp_local[GRID_SIZE];
-    __local DATA_TYPE max_local;
-
-    VEC_TYPE max_val_vec = (VEC_TYPE)(MINVAL);
-
-    // Number of iterations per work-item.
-    const uint width = (SRC_WIDTH / GRID_SIZE) >> LOG_VECTOR_SIZE;
-    // Calculate max of row
-    uint i = 0;
-    for(; i < width; ++i)
-    {
-        VEC_TYPE data_max = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)(src_addr + (i * GRID_SIZE * VECTOR_SIZE) * sizeof(DATA_TYPE)));
-        max_val_vec       = max(data_max, max_val_vec);
-    }
-#ifdef NON_MULTIPLE_OF_GRID_SIZE
-    // How many work-items needed to complete the computation.
-    int boundary_workitems = (SRC_WIDTH % (GRID_SIZE * VECTOR_SIZE)) / VECTOR_SIZE;
-    if(lid < boundary_workitems)
-    {
-        VEC_TYPE data_max = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)(src_addr + (i * GRID_SIZE * VECTOR_SIZE) * sizeof(DATA_TYPE)));
-        max_val_vec       = max(data_max, max_val_vec);
-    }
-#ifdef NON_MULTIPLE_OF_VECTOR_SIZE
-    SELECT_TYPE widx;
-    if(lid == 0)
-    {
-        // Handle non multiple of 4
-        VEC_TYPE data_max = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)(src_addr - VECTOR_SIZE_LEFTOVER * sizeof(DATA_TYPE)));
-        widx              = (SELECT_TYPE)VECTOR_SIZE_LEFTOVER > VEC_OFFS(SELECT_DATA_TYPE(DATA_TYPE), VECTOR_SIZE);
-        max_val_vec       = max(max_val_vec, select((VEC_TYPE)(MINVAL), data_max, widx));
-    }
-#endif /* NON_MULTIPLE_OF_VECTOR_SIZE */
-#endif /* NON_MULTIPLE_OF_GRID_SIZE */
-    tmp_local[lid] = max_val_vec;
-
-    barrier(CLK_LOCAL_MEM_FENCE);
-
-    if(GRID_SIZE >= 256)
-    {
-        if(lid < 128)
-        {
-            tmp_local[lid] = max(tmp_local[lid + 128], tmp_local[lid]);
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    if(GRID_SIZE >= 128)
-    {
-        if(lid < 64)
-        {
-            tmp_local[lid] = max(tmp_local[lid + 64], tmp_local[lid]);
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    if(GRID_SIZE >= 64)
-    {
-        if(lid < 32)
-        {
-            tmp_local[lid] = max(tmp_local[lid + 32], tmp_local[lid]);
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    if(GRID_SIZE >= 32)
-    {
-        if(lid < 16)
-        {
-            tmp_local[lid] = max(tmp_local[lid + 16], tmp_local[lid]);
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    if(GRID_SIZE >= 16)
-    {
-        if(lid < 8)
-        {
-            tmp_local[lid] = max(tmp_local[lid + 8], tmp_local[lid]);
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    if(GRID_SIZE >= 8)
-    {
-        if(lid < 4)
-        {
-            tmp_local[lid] = max(tmp_local[lid + 4], tmp_local[lid]);
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    if(GRID_SIZE >= 4)
-    {
-        if(lid < 2)
-        {
-            tmp_local[lid] = max(tmp_local[lid + 2], tmp_local[lid]);
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    if(lid == 0)
-    {
-        max_val_vec = max(tmp_local[lid + 1], tmp_local[lid]);
-        max_local   = MAX_REDUCE(max_val_vec, VECTOR_SIZE);
-    }
-    barrier(CLK_LOCAL_MEM_FENCE);
-
-    /* Second section */
-
-    // Set sum vector
-    VEC_TYPE  sum1D   = 0;
-    DATA_TYPE max_val = max_local;
-
-    // Shift values, exp and sum
-    for(i = 0; i < width; ++i)
-    {
-        VEC_TYPE data = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)(src_addr + (i * GRID_SIZE * VECTOR_SIZE) * sizeof(DATA_TYPE)));
-        data -= max_val;
-#ifdef BETA
-        data *= beta;
-#endif /* BETA */
-#ifdef LOG_SOFTMAX
-        VSTORE(VECTOR_SIZE)
-        (data, 0, (__global DATA_TYPE *)(dst_addr + (i * GRID_SIZE * VECTOR_SIZE) * sizeof(DATA_TYPE)));
-        data = exp(data);
-#else  /* LOG_SOFTMAX */
-        data = exp(data);
-        VSTORE(VECTOR_SIZE)
-        (data, 0, (__global DATA_TYPE *)(dst_addr + (i * GRID_SIZE * VECTOR_SIZE) * sizeof(DATA_TYPE)));
-#endif /* LOG_SOFTMAX */
-        sum1D += data;
-    }
-#ifdef NON_MULTIPLE_OF_GRID_SIZE
-    boundary_workitems = (SRC_WIDTH % (GRID_SIZE * VECTOR_SIZE)) / VECTOR_SIZE;
-    if(lid < boundary_workitems)
-    {
-        VEC_TYPE data = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)(__global DATA_TYPE *)(src_addr + (i * GRID_SIZE * VECTOR_SIZE) * sizeof(DATA_TYPE)));
-        data -= max_val;
-#ifdef BETA
-        data *= beta;
-#endif /* BETA */
-#ifdef LOG_SOFTMAX
-        VSTORE(VECTOR_SIZE)
-        (data, 0, (__global DATA_TYPE *)(dst_addr + (i * GRID_SIZE * VECTOR_SIZE) * sizeof(DATA_TYPE)));
-        data = exp(data);
-#else  /* LOG_SOFTMAX */
-        data = exp(data);
-        VSTORE(VECTOR_SIZE)
-        (data, 0, (__global DATA_TYPE *)(dst_addr + (i * GRID_SIZE * VECTOR_SIZE) * sizeof(DATA_TYPE)));
-#endif /* LOG_SOFTMAX */
-        sum1D += data;
-    }
-#ifdef NON_MULTIPLE_OF_VECTOR_SIZE
-    if(lid == 0)
-    {
-        // Handle non multiple of vector size ((GRID_SIZE * i * 4) + 4, 0); move 4 float positions ahead, *4 is due to the stride
-        VEC_TYPE data = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)(src_addr - VECTOR_SIZE_LEFTOVER * sizeof(DATA_TYPE)));
-        data -= max_val;
-#ifdef BETA
-        data *= beta;
-#endif /* BETA */
-#ifdef LOG_SOFTMAX
-        VSTORE_PARTIAL(VECTOR_SIZE, VECTOR_SIZE_LEFTOVER)
-        (data, 0, (__global DATA_TYPE *)(dst_addr - VECTOR_SIZE_LEFTOVER * sizeof(DATA_TYPE)));
-        data = exp(data);
-        data = select(0, data, widx);
-#else  /* LOG_SOFTMAX */
-        data = exp(data);
-        data = select(0, data, widx);
-        VSTORE_PARTIAL(VECTOR_SIZE, VECTOR_SIZE_LEFTOVER)
-        (data, 0, (__global DATA_TYPE *)(dst_addr - VECTOR_SIZE_LEFTOVER * sizeof(DATA_TYPE)));
-#endif /* LOG_SOFTMAX */
-        sum1D += data;
-    }
-#endif /* NON_MULTIPLE_OF_VECTOR_SIZE */
-#endif /* NON_MULTIPLE_OF_GRID_SIZE */
-    tmp_local[lid] = sum1D;
-
-    barrier(CLK_LOCAL_MEM_FENCE);
-
-    if(GRID_SIZE >= 256)
-    {
-        if(lid < 128)
-        {
-            tmp_local[lid] += tmp_local[lid + 128];
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    if(GRID_SIZE >= 128)
-    {
-        if(lid < 64)
-        {
-            tmp_local[lid] += tmp_local[lid + 64];
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    if(GRID_SIZE >= 64)
-    {
-        if(lid < 32)
-        {
-            tmp_local[lid] += tmp_local[lid + 32];
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    if(GRID_SIZE >= 32)
-    {
-        if(lid < 16)
-        {
-            tmp_local[lid] += tmp_local[lid + 16];
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    if(GRID_SIZE >= 16)
-    {
-        if(lid < 8)
-        {
-            tmp_local[lid] += tmp_local[lid + 8];
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    if(GRID_SIZE >= 8)
-    {
-        if(lid < 4)
-        {
-            tmp_local[lid] += tmp_local[lid + 4];
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    if(GRID_SIZE >= 4)
-    {
-        if(lid < 2)
-        {
-            tmp_local[lid] += tmp_local[lid + 2];
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    if(lid == 0)
-    {
-        sum1D = (tmp_local[lid + 1] + tmp_local[lid]);
-        // Perform sum reduction
-        *((__global DATA_TYPE *)sum.ptr) = SUM_REDUCE(sum1D, VECTOR_SIZE);
-    }
-}
-
-#endif // defined(SRC_WIDTH) && defined(LOG_VECTOR_SIZE) && defined(MINVAL)
-#endif // defined(DATA_TYPE) && defined(MIN_VALUE) && defined(VECTOR_SIZE) && defined(VECTOR_SIZE_LEFTOVER)
\ No newline at end of file
+#undef MIN_VALUE_float
+#undef MIN_VALUE_half
+#undef MIN_VALUE_char
+#undef MIN_VALUE_uchar