blob: 53fb51557c5edbd21d3fd6af017d8023f3ec9dc4 [file] [log] [blame]
Anthony Barbier7068f992017-10-26 15:23:08 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
layout(local_size_x = LOCAL_SIZE_X, local_size_y = LOCAL_SIZE_Y, local_size_z = LOCAL_SIZE_Z) in;

#include "helpers_cs.h"

#if defined(DATA_TYPE_FP16)
precision mediump float;
#endif /* DATA_TYPE_FP16 */

// The kernel bodies read the epsilon value from the macro ESPILON (sic). Also accept the
// correctly spelled EPSILON so either define works; an explicitly defined ESPILON keeps precedence.
#if !defined(ESPILON) && defined(EPSILON)
#define ESPILON EPSILON
#endif /* !ESPILON && EPSILON */

// Element-wise arithmetic helpers used by both kernel variants.
#define ADD_OP(a, b) ((a) + (b))
#define SUB_OP(a, b) ((a) - (b))
#define MUL_OP(a, b) ((a) * (b))
#define INVSQRT_OP(a) inversesqrt((a))
// Identity for floating-point types: no saturating conversion is applied here.
#define SQCVT_SAT(a) (a)
38
/** Apply batch normalization to a 3D tensor:
 *
 *     dst = gamma * (src - mean) / sqrt(var + epsilon) + beta
 *
 * The mean, var, beta and gamma vectors are indexed by the z coordinate of the invocation
 * (gl_GlobalInvocationID.z), i.e. each slice uses its own parameter set.
 *
 * @note The data type must be selected at compile time with one of "#define DATA_TYPE_FP32" or "#define DATA_TYPE_FP16".
 * @note The epsilon added to the variance must be passed as a preprocessor define. The kernel body reads it
 *       under the name ESPILON (sic), e.g. "#define ESPILON 0.001".
 * @note The workgroup size must be passed as LOCAL_SIZE_X, LOCAL_SIZE_Y and LOCAL_SIZE_Z.
 *
 * @param[in]  src_ptr     Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]  src_attrs   The attributes of the source tensor
 * @param[out] dst_ptr     Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_attrs   The attributes of the destination tensor
 * @param[in]  mean_ptr    Pointer to the mean vector. Supported data types: same as @p src_ptr
 * @param[in]  mean_attrs  The attributes of the mean vector
 * @param[in]  var_ptr     Pointer to the variance vector. Supported data types: same as @p src_ptr
 * @param[in]  var_attrs   The attributes of the variance vector
 * @param[in]  beta_ptr    Pointer to the beta (shift) vector. Supported data types: same as @p src_ptr
 * @param[in]  beta_attrs  The attributes of the beta vector
 * @param[in]  gamma_ptr   Pointer to the gamma (scale) vector. Supported data types: same as @p src_ptr
 * @param[in]  gamma_attrs The attributes of the gamma vector
 */
SHADER_PARAMS_DECLARATION
{
    // NOTE(review): presumably filled by the host in this exact order -- keep member order unchanged.
    Tensor3DAttributes src_attrs;
    Tensor3DAttributes dst_attrs;
    VectorAttributes   mean_attrs;
    VectorAttributes   var_attrs;
    VectorAttributes   beta_attrs;
    VectorAttributes   gamma_attrs;
};
66
#ifdef DATA_TYPE_FP32
TENSOR_DECLARATION(1, srcBuffer, float, src_ptr, src_shift, 2, readonly);
TENSOR_DECLARATION(2, dstBuffer, float, dst_ptr, dst_shift, 2, writeonly);
TENSOR_DECLARATION(3, meanBuffer, float, mean_ptr, mean_shift, 2, readonly);
TENSOR_DECLARATION(4, varBuffer, float, var_ptr, var_shift, 2, readonly);
TENSOR_DECLARATION(5, betaBuffer, float, beta_ptr, beta_shift, 2, readonly);
TENSOR_DECLARATION(6, gammaBuffer, float, gamma_ptr, gamma_shift, 2, readonly);

/** FP32 batch normalization: each invocation normalizes a single element.
 *
 * The per-slice parameters (mean, var, gamma, beta) are selected with the invocation's
 * z coordinate and addressed through each parameter vector's own stride_x.
 */
void main(void)
{
    Tensor3DIterator src_iter   = CONVERT_TO_TENSOR3D_ITERATOR(src_attrs, src_shift);
    Tensor3DIterator dst_iter   = CONVERT_TO_TENSOR3D_ITERATOR(dst_attrs, dst_shift);
    VectorIterator   mean_iter  = CONVERT_TO_VECTOR_ITERATOR(mean_attrs, mean_shift);
    VectorIterator   var_iter   = CONVERT_TO_VECTOR_ITERATOR(var_attrs, var_shift);
    VectorIterator   beta_iter  = CONVERT_TO_VECTOR_ITERATOR(beta_attrs, beta_shift);
    VectorIterator   gamma_iter = CONVERT_TO_VECTOR_ITERATOR(gamma_attrs, gamma_shift);

    uint current_slice = gl_GlobalInvocationID.z;

    float input_value = LOAD_CURRENT_ITEM(src_ptr, src_iter);

    // denominator = 1 / sqrt(var + epsilon)
    float var_value   = LOAD(var_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(var_iter, current_slice * var_attrs.stride_x));
    float denominator = INVSQRT_OP(ADD_OP(var_value, SQCVT_SAT(float(ESPILON))));

    // x_bar = (x - mean) * denominator
    float mean_value = LOAD(mean_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(mean_iter, current_slice * mean_attrs.stride_x));
    float x_bar      = MUL_OP(SUB_OP(input_value, mean_value), denominator);

    // Each parameter vector is addressed with its own stride.
    float gamma_param = LOAD(gamma_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(gamma_iter, current_slice * gamma_attrs.stride_x));
    float beta_param  = LOAD(beta_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(beta_iter, current_slice * beta_attrs.stride_x));

    STORE_CURRENT_ITEM(dst_ptr, dst_iter, ADD_OP(MUL_OP(gamma_param, x_bar), beta_param));
}
107
#elif defined(DATA_TYPE_FP16)
TENSOR_DECLARATION(1, srcBuffer, uvec2, src_ptr, src_shift, 3, readonly);
TENSOR_DECLARATION(2, dstBuffer, uvec2, dst_ptr, dst_shift, 3, writeonly);
TENSOR_DECLARATION(3, meanBuffer, uvec2, mean_ptr, mean_shift, 3, readonly);
TENSOR_DECLARATION(4, varBuffer, uvec2, var_ptr, var_shift, 3, readonly);
TENSOR_DECLARATION(5, betaBuffer, uvec2, beta_ptr, beta_shift, 3, readonly);
TENSOR_DECLARATION(6, gammaBuffer, uvec2, gamma_ptr, gamma_shift, 3, readonly);

/** FP16 batch normalization: each invocation normalizes four packed half values.
 *
 * All four source values belong to the same slice (gl_GlobalInvocationID.z) and share one
 * parameter set. The parameter vectors are read as packed groups of four halves, and the
 * value for this slice is taken from lane (current_slice % 4) of each group.
 * NOTE(review): this assumes TENSOR_OFFSET_ADVANCE_IN_BYTES rounds the byte offset down to the
 * enclosing 4-half (uvec2) group -- confirm against helpers_cs.h.
 */
void main(void)
{
    Tensor3DIterator src_iter   = CONVERT_TO_TENSOR3D_ITERATOR(src_attrs, src_shift);
    Tensor3DIterator dst_iter   = CONVERT_TO_TENSOR3D_ITERATOR(dst_attrs, dst_shift);
    VectorIterator   mean_iter  = CONVERT_TO_VECTOR_ITERATOR(mean_attrs, mean_shift);
    VectorIterator   var_iter   = CONVERT_TO_VECTOR_ITERATOR(var_attrs, var_shift);
    VectorIterator   beta_iter  = CONVERT_TO_VECTOR_ITERATOR(beta_attrs, beta_shift);
    VectorIterator   gamma_iter = CONVERT_TO_VECTOR_ITERATOR(gamma_attrs, gamma_shift);

    uint current_slice = gl_GlobalInvocationID.z;
    // Lane of the packed parameter group that holds this slice's value (0..3).
    int lane = int(current_slice % uint(4));

    vec4 input_values = LOAD_UNPACK4_CURRENT_ITEM_HALF(src_ptr, src_iter);
    vec4 var_values   = LOAD_UNPACK4_HALF(var_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(var_iter, current_slice * var_attrs.stride_x));
    vec4 mean_values  = LOAD_UNPACK4_HALF(mean_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(mean_iter, current_slice * mean_attrs.stride_x));
    // Each parameter vector is addressed with its own stride.
    vec4 gamma_values = LOAD_UNPACK4_HALF(gamma_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(gamma_iter, current_slice * gamma_attrs.stride_x));
    vec4 beta_values  = LOAD_UNPACK4_HALF(beta_ptr, TENSOR_OFFSET_ADVANCE_IN_BYTES(beta_iter, current_slice * beta_attrs.stride_x));

    // denominator = 1 / sqrt(var + epsilon)
    float denominator = INVSQRT_OP(ADD_OP(var_values[lane], SQCVT_SAT(float(ESPILON))));

    // x_bar = (x - mean) * denominator, applied to all four packed values
    vec4 x_bar = MUL_OP(SUB_OP(input_values, mean_values[lane]), denominator);

    vec4 result = ADD_OP(MUL_OP(gamma_values[lane], x_bar), beta_values[lane]);

    STORE_PACK4_CURRENT_ITEM_HALF(dst_ptr, dst_iter, result);
}
#else /* neither DATA_TYPE_FP32 nor DATA_TYPE_FP16 */
#error One of DATA_TYPE_FP32 or DATA_TYPE_FP16 must be defined
#endif /* DATA_TYPE_FP32 / DATA_TYPE_FP16 */