blob: 7629b255b73a482ab886d7ab4f720cbe5af20ee2 [file] [log] [blame]
/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

layout(local_size_x = LOCAL_SIZE_X, local_size_y = LOCAL_SIZE_Y, local_size_z = LOCAL_SIZE_Z) in;

#include "helpers_cs.h"

#if defined(DATA_TYPE_FP16)
// FP16 data is unpacked into mediump floats; the FP32 path keeps the compute-shader default precision.
precision mediump float;
#endif /* DATA_TYPE_FP16 */

// The kernel bodies read the epsilon term from the macro ESPILON (sic). Also accept the
// correctly spelled EPSILON so that either spelling configures the shader; an explicit
// ESPILON definition always takes precedence.
#if !defined(ESPILON) && defined(EPSILON)
#define ESPILON EPSILON
#endif /* !defined(ESPILON) && defined(EPSILON) */

// Element-wise arithmetic helpers used by the kernel bodies.
#define ADD_OP(a, b) ((a) + (b))
#define SUB_OP(a, b) ((a) - (b))
#define MUL_OP(a, b) ((a) * (b))
#define INVSQRT_OP(a) inversesqrt((a))
// Identity: no saturating conversion is applied to floating-point data.
#define SQCVT_SAT(a) (a)

// Optional fused activation applied to the normalized output:
//   LU_BRELU  : clamp to [B_VAL, A_VAL]
//   BRELU     : clamp to [0, A_VAL]
//   RELU      : max(x, 0)
//   otherwise : identity
#if defined(LU_BRELU)
#define ACTIVATION_FUNC(x) min(max(x, float(B_VAL)), float(A_VAL))
#elif defined(BRELU)
#define ACTIVATION_FUNC(x) min(max(x, float(0)), float(A_VAL))
#elif defined(RELU)
#define ACTIVATION_FUNC(x) max(x, float(0))
#else /* no fused activation */
#define ACTIVATION_FUNC(x) (x)
#endif /* LU_BRELU / BRELU / RELU */

/** Apply batch normalization, optionally followed by a fused activation.
 *
 * Each output element is computed as
 *     dst = ACTIVATION_FUNC(gamma * (src - mean) * inversesqrt(var + epsilon) + beta)
 * where mean, var, beta and gamma are per-channel vectors indexed by the z (slice)
 * coordinate of the invocation.
 *
 * @note The data type must be passed at compile time with "#define DATA_TYPE_FP32" or "#define DATA_TYPE_FP16".
 * @note Epsilon must be passed at compile time as "#define ESPILON <value>" (this spelling is what
 *       the kernel body reads), e.g. "#define ESPILON 0.1".
 * @note An optional fused activation is selected with "#define RELU", "#define BRELU" (needs A_VAL)
 *       or "#define LU_BRELU" (needs A_VAL and B_VAL).
 * @note The workgroup size is taken from LOCAL_SIZE_X, LOCAL_SIZE_Y and LOCAL_SIZE_Z.
 *
 * @param[in]  src_ptr     Pointer to the first source tensor. Supported data types: F16/F32
 * @param[in]  src_attrs   The attributes of the source tensor
 * @param[out] dst_ptr     Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_attrs   The attributes of the destination tensor
 * @param[in]  mean_ptr    Pointer to the mean source tensor. Supported data types: same as @p src_ptr
 * @param[in]  mean_attrs  The attributes of the mean tensor
 * @param[in]  var_ptr     Pointer to the var tensor. Supported data types: same as @p src_ptr
 * @param[in]  var_attrs   The attributes of the var tensor
 * @param[in]  beta_ptr    Pointer to the beta source tensor. Supported data types: same as @p src_ptr
 * @param[in]  beta_attrs  The attributes of the beta tensor
 * @param[in]  gamma_ptr   Pointer to the gamma source tensor. Supported data types: same as @p src_ptr
 * @param[in]  gamma_attrs The attributes of the gamma tensor
 */
SHADER_PARAMS_DECLARATION
{
    Tensor3DAttributes src_attrs;
    Tensor3DAttributes dst_attrs;
    VectorAttributes   mean_attrs;
    VectorAttributes   var_attrs;
    VectorAttributes   beta_attrs;
    VectorAttributes   gamma_attrs;
};

#ifdef DATA_TYPE_FP32
TENSOR_DECLARATION(1, srcBuffer, float, src_ptr, src_shift, 2, readonly);
TENSOR_DECLARATION(2, dstBuffer, float, dst_ptr, dst_shift, 2, writeonly);
TENSOR_DECLARATION(3, meanBuffer, float, mean_ptr, mean_shift, 2, readonly);
TENSOR_DECLARATION(4, varBuffer, float, var_ptr, var_shift, 2, readonly);
TENSOR_DECLARATION(5, betaBuffer, float, beta_ptr, beta_shift, 2, readonly);
TENSOR_DECLARATION(6, gammaBuffer, float, gamma_ptr, gamma_shift, 2, readonly);

/* FP32 path: each invocation normalizes one source element.
 * The per-channel parameters (mean, var, gamma, beta) are selected by the slice (z) coordinate.
 */
void main(void)
{
    Tensor3DIterator src_iter   = CONVERT_TO_TENSOR3D_ITERATOR(src_attrs, src_shift);
    Tensor3DIterator dst_iter   = CONVERT_TO_TENSOR3D_ITERATOR(dst_attrs, dst_shift);
    VectorIterator   mean_iter  = CONVERT_TO_VECTOR_ITERATOR(mean_attrs, mean_shift);
    VectorIterator   var_iter   = CONVERT_TO_VECTOR_ITERATOR(var_attrs, var_shift);
    VectorIterator   beta_iter  = CONVERT_TO_VECTOR_ITERATOR(beta_attrs, beta_shift);
    VectorIterator   gamma_iter = CONVERT_TO_VECTOR_ITERATOR(gamma_attrs, gamma_shift);

    // Channel index of this invocation.
    uint current_slice = gl_GlobalInvocationID.z;

    float input_value = LOAD_CURRENT_ITEM(src_ptr, src_iter);
    float variance    = LOAD(var_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(var_iter, current_slice * var_attrs.stride_x));
    float mean        = LOAD(mean_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(mean_iter, current_slice * mean_attrs.stride_x));
    // Each parameter vector is addressed with its own stride (gamma must not borrow beta's).
    float gamma_param = LOAD(gamma_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(gamma_iter, current_slice * gamma_attrs.stride_x));
    float beta_param  = LOAD(beta_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(beta_iter, current_slice * beta_attrs.stride_x));

    // x_bar = (x - mean) * inversesqrt(var + epsilon)
    float denominator = INVSQRT_OP(ADD_OP(variance, SQCVT_SAT(float(ESPILON))));
    float x_bar       = MUL_OP(SUB_OP(input_value, mean), denominator);

    STORE_CURRENT_ITEM(dst_ptr, dst_iter, ACTIVATION_FUNC(ADD_OP(MUL_OP(gamma_param, x_bar), beta_param)));
}

#elif defined(DATA_TYPE_FP16)
TENSOR_DECLARATION(1, srcBuffer, uvec2, src_ptr, src_shift, 3, readonly);
TENSOR_DECLARATION(2, dstBuffer, uvec2, dst_ptr, dst_shift, 3, writeonly);
TENSOR_DECLARATION(3, meanBuffer, uvec2, mean_ptr, mean_shift, 3, readonly);
TENSOR_DECLARATION(4, varBuffer, uvec2, var_ptr, var_shift, 3, readonly);
TENSOR_DECLARATION(5, betaBuffer, uvec2, beta_ptr, beta_shift, 3, readonly);
TENSOR_DECLARATION(6, gammaBuffer, uvec2, gamma_ptr, gamma_shift, 3, readonly);

/* FP16 path: each invocation normalizes four half-precision source values that are
 * stored packed in one uvec2 and unpacked into a vec4. All four values share the
 * slice (z) coordinate and therefore the same per-channel parameters.
 */
void main(void)
{
    Tensor3DIterator src_iter   = CONVERT_TO_TENSOR3D_ITERATOR(src_attrs, src_shift);
    Tensor3DIterator dst_iter   = CONVERT_TO_TENSOR3D_ITERATOR(dst_attrs, dst_shift);
    VectorIterator   mean_iter  = CONVERT_TO_VECTOR_ITERATOR(mean_attrs, mean_shift);
    VectorIterator   var_iter   = CONVERT_TO_VECTOR_ITERATOR(var_attrs, var_shift);
    VectorIterator   beta_iter  = CONVERT_TO_VECTOR_ITERATOR(beta_attrs, beta_shift);
    VectorIterator   gamma_iter = CONVERT_TO_VECTOR_ITERATOR(gamma_attrs, gamma_shift);

    // Channel index of this invocation.
    uint current_slice = gl_GlobalInvocationID.z;

    vec4 input_values = LOAD_UNPACK4_CURRENT_ITEM_HALF(src_ptr, src_iter);

    // Parameter vectors are read as packed groups of four halves; the entry for this
    // channel sits in component (current_slice % 4) of the loaded group.
    // NOTE(review): relies on LOAD_UNPACK4_HALF returning the 4-half group that contains
    // the addressed element (element_shift 3) — confirm against helpers_cs.h.
    vec4 var_group   = LOAD_UNPACK4_HALF(var_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(var_iter, current_slice * var_attrs.stride_x));
    vec4 mean_group  = LOAD_UNPACK4_HALF(mean_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(mean_iter, current_slice * mean_attrs.stride_x));
    // Each parameter vector is addressed with its own stride (gamma must not borrow beta's).
    vec4 gamma_group = LOAD_UNPACK4_HALF(gamma_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(gamma_iter, current_slice * gamma_attrs.stride_x));
    vec4 beta_group  = LOAD_UNPACK4_HALF(beta_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(beta_iter, current_slice * beta_attrs.stride_x));

    // Component selector in [0, 3]; always in range for dynamic vec4 indexing.
    uint lane = current_slice % 4u;

    // x_bar = (x - mean) * inversesqrt(var + epsilon), applied to all four values at once.
    float denominator = INVSQRT_OP(ADD_OP(var_group[lane], SQCVT_SAT(float(ESPILON))));
    vec4  x_bar       = MUL_OP(SUB_OP(input_values, mean_group[lane]), denominator);
    vec4  result      = ACTIVATION_FUNC(ADD_OP(MUL_OP(gamma_group[lane], x_bar), beta_group[lane]));

    STORE_PACK4_CURRENT_ITEM_HALF(dst_ptr, dst_iter, result);
}
#endif /* DATA_TYPE_FP16 */