/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

// Work-group dimensions. LOCAL_SIZE_X/Y/Z are not defined in this file; they have to be
// provided as preprocessor definitions before this line is compiled.
layout(local_size_x = LOCAL_SIZE_X, local_size_y = LOCAL_SIZE_Y, local_size_z = LOCAL_SIZE_Z) in;

#include "helpers.h"

// The default float precision follows the data type the shader is built for.
#if defined(DATA_TYPE_FP32)
precision highp float;
#elif defined(DATA_TYPE_FP16)
precision mediump float;
#endif /* DATA_TYPE_FP32 / DATA_TYPE_FP16 */

// Element-wise arithmetic wrappers used by both data-type variants of main().
#define ADD_OP(a, b) ((a) + (b))
#define SUB_OP(a, b) ((a) - (b))
#define MUL_OP(a, b) ((a) * (b))
// Reciprocal square root of the argument.
#define INVSQRT_OP(a) inversesqrt(a)
// Pass-through in this shader: expands to its argument unchanged.
#define SQCVT_SAT(a) (a)

// Parameters of every tensor/vector the shader accesses, one declaration per buffer.
// The block uses the std140 layout, so the member order below is part of the interface
// and must not be rearranged.
// NOTE(review): the *_PARAM_DECLARATION macros are not defined in this file; presumably
// they expand to the stride/step/offset fields documented on main() - confirm in helpers.h.
layout(std140) uniform shader_params
{
    TENSOR3D_PARAM_DECLARATION(src);
    TENSOR3D_PARAM_DECLARATION(dst);
    VECTOR_PARAM_DECLARATION(mean);
    VECTOR_PARAM_DECLARATION(var);
    VECTOR_PARAM_DECLARATION(beta);
    VECTOR_PARAM_DECLARATION(gamma);
};

#if defined(DATA_TYPE_FP32)
// FP32 variant: every buffer is an array of float. The five inputs are declared
// read-only and the destination write-only.
// NOTE(review): the second argument numbers the buffers 1..6, presumably the binding
// point - confirm against BUFFER_DECLARATION in helpers.h.
BUFFER_DECLARATION(src, 1, float, readonly);
BUFFER_DECLARATION(dst, 2, float, writeonly);
BUFFER_DECLARATION(mean, 3, float, readonly);
BUFFER_DECLARATION(var, 4, float, readonly);
BUFFER_DECLARATION(beta, 5, float, readonly);
BUFFER_DECLARATION(gamma, 6, float, readonly);

/** Apply batch normalization (FP32 variant).
 *
 * Each invocation normalizes one element of the source tensor:
 *
 *   dst = gamma * (src - mean) * inversesqrt(var + ESPILON) + beta
 *
 * The mean, var, beta and gamma elements are selected by gl_GlobalInvocationID.z.
 *
 * @note The epsilon of the batch normalization equation must be given as a preprocessor argument. The shader body reads the macro spelled "ESPILON" (sic), e.g. "#define ESPILON 0.1"
 *
 * @param[in]  src_ptr                             Pointer to the first source tensor. Supported data types: F32
 * @param[in]  src_stride_x                        Stride of the first source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                          src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                        Stride of the first source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                          src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                        Stride of the first source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                          src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes   The offset of the first element in the first source tensor
 * @param[out] dst_ptr                             Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                        Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                          dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                        Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                          dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                        Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                          dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes   The offset of the first element in the destination tensor
 * @param[in]  mean_ptr                            Pointer to the mean source tensor. Supported data types: same as @p src_ptr
 * @param[in]  mean_stride_x                       Stride of the mean source tensor in X dimension (in bytes)
 * @param[in]  mean_step_x                         mean_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  mean_offset_first_element_in_bytes  The offset of the first element in the mean source tensor
 * @param[in]  var_ptr                             Pointer to the var tensor. Supported data types: same as @p src_ptr
 * @param[in]  var_stride_x                        Stride of the var tensor in X dimension (in bytes)
 * @param[in]  var_step_x                          var_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  var_offset_first_element_in_bytes   The offset of the first element in the var source tensor
 * @param[in]  beta_ptr                            Pointer to the beta source tensor. Supported data types: same as @p src_ptr
 * @param[in]  beta_stride_x                       Stride of the beta source tensor in X dimension (in bytes)
 * @param[in]  beta_step_x                         beta_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  beta_offset_first_element_in_bytes  The offset of the first element in the beta source tensor
 * @param[in]  gamma_ptr                           Pointer to the gamma source tensor. Supported data types: same as @p src_ptr
 * @param[in]  gamma_stride_x                      Stride of the gamma source tensor in X dimension (in bytes)
 * @param[in]  gamma_step_x                        gamma_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  gamma_offset_first_element_in_bytes The offset of the first element in the gamma source tensor
 */
void main(void)
{
    Tensor3D src   = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst   = CONVERT_TO_TENSOR3D_STRUCT(dst);
    Vector   mean  = CONVERT_TO_VECTOR_STRUCT(mean);
    Vector   var   = CONVERT_TO_VECTOR_STRUCT(var);
    Vector   beta  = CONVERT_TO_VECTOR_STRUCT(beta);
    Vector   gamma = CONVERT_TO_VECTOR_STRUCT(gamma);

    // The Z invocation index selects which mean/var/beta/gamma element is applied.
    uint current_slice = gl_GlobalInvocationID.z;

    float input_value = src_ptr[src.current_offset];

    // stride_x is expressed in bytes, so only the stride term is converted to a float
    // index (>> 2). '+' binds tighter than '>>', hence the explicit parentheses.
    // NOTE(review): this assumes Vector::current_offset is already a float index, like
    // src.current_offset which is used unshifted above - confirm against helpers.h.
    float denominator = var_ptr[var.current_offset + ((current_slice * var.stride_x) >> 2)];
    denominator       = INVSQRT_OP(ADD_OP(denominator, SQCVT_SAT(float(ESPILON))));

    // Calculate x bar
    float numerator = mean_ptr[mean.current_offset + ((current_slice * mean.stride_x) >> 2)];
    numerator       = SUB_OP(input_value, numerator);
    float x_bar     = MUL_OP(numerator, denominator);

    // Each vector is indexed with its own stride (gamma previously used beta's stride).
    float gamma_param = gamma_ptr[gamma.current_offset + ((current_slice * gamma.stride_x) >> 2)];
    float beta_param  = beta_ptr[beta.current_offset + ((current_slice * beta.stride_x) >> 2)];

    // Scale, shift and store the result
    dst_ptr[dst.current_offset] = ADD_OP(MUL_OP(gamma_param, x_bar), beta_param);
}

#elif defined(DATA_TYPE_FP16)
// FP16 variant: buffers are declared as arrays of uint; main() unpacks two half floats
// from each 32-bit word with unpackHalf2x16 and packs the result with packHalf2x16.
// The five inputs are only ever read, so they carry the same readonly qualifier as in
// the FP32 variant; the destination stays write-only.
BUFFER_DECLARATION(src, 1, uint, readonly);
BUFFER_DECLARATION(dst, 2, uint, writeonly);
BUFFER_DECLARATION(mean, 3, uint, readonly);
BUFFER_DECLARATION(var, 4, uint, readonly);
BUFFER_DECLARATION(beta, 5, uint, readonly);
BUFFER_DECLARATION(gamma, 6, uint, readonly);

/** Unpack one of the two half floats stored in a 32-bit word.
 *
 * @param[in] packed_pair Word holding two packed half floats
 * @param[in] upper_lane  false selects the first unpacked component (.x), true the second (.y)
 *
 * @return The selected half float converted to float
 */
float unpack_half_lane(uint packed_pair, bool upper_lane)
{
    vec2 pair = unpackHalf2x16(packed_pair);
    return upper_lane ? pair.y : pair.x;
}

/** Apply batch normalization (FP16 variant).
 *
 * Each invocation reads one 32-bit word of the source tensor, i.e. two half floats, and
 * normalizes both with the same parameters:
 *
 *   dst = gamma * (src - mean) * inversesqrt(var + ESPILON) + beta
 *
 * The mean, var, beta and gamma elements are selected by gl_GlobalInvocationID.z. Two
 * parameters share a word: an even index takes the first half (.x), an odd index the second (.y).
 *
 * @note The epsilon of the batch normalization equation must be given as a preprocessor argument. The shader body reads the macro spelled "ESPILON" (sic), e.g. "#define ESPILON 0.1"
 *
 * @param[in]  src_ptr                             Pointer to the first source tensor. Supported data types: F16
 * @param[in]  src_stride_x                        Stride of the first source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                          src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                        Stride of the first source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                          src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                        Stride of the first source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                          src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes   The offset of the first element in the first source tensor
 * @param[out] dst_ptr                             Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                        Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                          dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                        Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                          dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                        Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                          dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes   The offset of the first element in the destination tensor
 * @param[in]  mean_ptr                            Pointer to the mean source tensor. Supported data types: same as @p src_ptr
 * @param[in]  mean_stride_x                       Stride of the mean source tensor in X dimension (in bytes)
 * @param[in]  mean_step_x                         mean_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  mean_offset_first_element_in_bytes  The offset of the first element in the mean source tensor
 * @param[in]  var_ptr                             Pointer to the var tensor. Supported data types: same as @p src_ptr
 * @param[in]  var_stride_x                        Stride of the var tensor in X dimension (in bytes)
 * @param[in]  var_step_x                          var_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  var_offset_first_element_in_bytes   The offset of the first element in the var source tensor
 * @param[in]  beta_ptr                            Pointer to the beta source tensor. Supported data types: same as @p src_ptr
 * @param[in]  beta_stride_x                       Stride of the beta source tensor in X dimension (in bytes)
 * @param[in]  beta_step_x                         beta_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  beta_offset_first_element_in_bytes  The offset of the first element in the beta source tensor
 * @param[in]  gamma_ptr                           Pointer to the gamma source tensor. Supported data types: same as @p src_ptr
 * @param[in]  gamma_stride_x                      Stride of the gamma source tensor in X dimension (in bytes)
 * @param[in]  gamma_step_x                        gamma_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  gamma_offset_first_element_in_bytes The offset of the first element in the gamma source tensor
 */
void main(void)
{
    Tensor3D src   = CONVERT_TO_TENSOR3D_STRUCT_FP16(src);
    Tensor3D dst   = CONVERT_TO_TENSOR3D_STRUCT_FP16(dst);
    Vector   mean  = CONVERT_TO_VECTOR_STRUCT_FP16(mean);
    Vector   var   = CONVERT_TO_VECTOR_STRUCT_FP16(var);
    Vector   beta  = CONVERT_TO_VECTOR_STRUCT_FP16(beta);
    Vector   gamma = CONVERT_TO_VECTOR_STRUCT_FP16(gamma);

    // The Z invocation index selects the parameter element; its parity selects which
    // half of the 32-bit word holds that element.
    uint current_slice = gl_GlobalInvocationID.z;
    bool upper_lane    = (current_slice % uint(2)) != uint(0);

    // Offsets are byte offsets in this variant; >> 2 converts them to a word index.
    vec2 input_value = unpackHalf2x16(src_ptr[src.current_offset >> 2]);

    float denominator = unpack_half_lane(var_ptr[(var.current_offset + current_slice * var.stride_x) >> 2], upper_lane);
    denominator       = INVSQRT_OP(ADD_OP(denominator, SQCVT_SAT(float(ESPILON))));

    // Calculate x bar for both halves of the source word
    float numerator = unpack_half_lane(mean_ptr[(mean.current_offset + current_slice * mean.stride_x) >> 2], upper_lane);
    vec2  x_bar     = MUL_OP(SUB_OP(input_value, numerator), denominator);

    // Each vector is indexed with its own stride (gamma previously used beta's stride).
    float gamma_param = unpack_half_lane(gamma_ptr[(gamma.current_offset + current_slice * gamma.stride_x) >> 2], upper_lane);
    float beta_param  = unpack_half_lane(beta_ptr[(beta.current_offset + current_slice * beta.stride_x) >> 2], upper_lane);

    // Scale, shift, repack to two half floats and store
    dst_ptr[dst.current_offset >> 2] = packHalf2x16(ADD_OP(MUL_OP(gamma_param, x_bar), beta_param));
}
#endif /*DATA_TYPE_FP32*/