blob: a8d8d32b122da92d5ffef14aaa049132d195e935 [file] [log] [blame]
Gunes Bayiraecb5d92022-12-18 21:31:29 +00001/*
Ramy Elgammal002e6532023-01-11 18:48:04 +00002 * Copyright (c) 2022-2023 Arm Limited.
Gunes Bayiraecb5d92022-12-18 21:31:29 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#include "src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplateLogits1DMaxShiftExpSum.h"
26
Matthew Bentham314d3e22023-06-23 10:53:52 +000027#include "arm_compute/core/utils/helpers/AdjustVecSize.h"
28#include "arm_compute/core/utils/StringUtils.h"
Gunes Bayiraecb5d92022-12-18 21:31:29 +000029#include "src/core/helpers/WindowHelpers.h"
30#include "src/dynamic_fusion/sketch/gpu/GpuKernelComponentGroup.h"
31#include "support/StringSupport.h"
32
33namespace arm_compute
34{
35namespace experimental
36{
37namespace dynamic_fusion
38{
namespace
{
// Preferred vector width (N0) along the reduction axis, before adjust_vec_size() fits it to the tensor width.
constexpr unsigned int serial_vector_size{ 8U };
} // namespace
Gunes Bayiraecb5d92022-12-18 21:31:29 +000043ClTemplateLogits1DMaxShiftExpSum::ClTemplateLogits1DMaxShiftExpSum(ComponentId id,
44 const ArgumentPack<ITensorInfo> &tensors,
45 const Attributes &attributes)
46 : IGpuTemplateComponentWriter{ id, tensors },
47 _src{},
48 _sum{},
49 _dst{},
50 _attributes{ attributes }
51{
52 _src = this->tensors().get_const_tensor(TensorType::ACL_SRC_0);
53 _sum = this->tensors().get_const_tensor(TensorType::ACL_DST_0);
54 _dst = this->tensors().get_const_tensor(TensorType::ACL_DST_1);
Ramy Elgammal002e6532023-01-11 18:48:04 +000055 ARM_COMPUTE_ERROR_ON_NULLPTR(_src);
56 ARM_COMPUTE_ERROR_ON_NULLPTR(_sum);
57 ARM_COMPUTE_ERROR_ON_NULLPTR(_dst);
Gunes Bayiraecb5d92022-12-18 21:31:29 +000058}
59
60std::string ClTemplateLogits1DMaxShiftExpSum::get_name() const
61{
62 return "logits_1d_max_shift_exp_sum";
63}
64
65std::string ClTemplateLogits1DMaxShiftExpSum::get_component_code(const ComponentGroup &comp_group) const
66{
67 ARM_COMPUTE_UNUSED(comp_group);
68
69 std::string code = R"_(
70//------------------ START KERNEL {{meta_kernel_id}} ---------------------
71#define VEC_TYPE VEC_DATA_TYPE({{DATA_TYPE}}, N0)
72#define SELECT_TYPE SELECT_VEC_DATA_TYPE({{DATA_TYPE}}, N0)
73{
74 __global uchar *src_addr = {{src}}_ptr + {{src}}_offset_first_element_in_bytes + g_ind_1 * {{src}}_stride_y + g_ind_2 * {{src}}_stride_z;
75 __global uchar *dst_addr = {{dst}}_ptr + {{dst}}_offset_first_element_in_bytes + g_ind_1 * {{dst}}_stride_y + g_ind_2 * {{dst}}_stride_z;
Gunes Bayiraecb5d92022-12-18 21:31:29 +000076 Image sum = CONVERT_TENSOR3D_TO_IMAGE_STRUCT({{sum}});
77 VEC_TYPE max_val_vec = (VEC_TYPE)({{MINVAL}});
78)_";
79
80 const bool beta_defined = (_attributes.beta() != 1.f);
81
82 if(beta_defined)
83 {
84 code += R"_(
85 VEC_TYPE beta = (VEC_TYPE){{BETA}};
86)_";
87 }
88
89 constexpr unsigned int _serial_vector_size = 8;
90 const unsigned int reduction_dim_size = _src->dimension(0);
91 const unsigned int vector_size = adjust_vec_size(_serial_vector_size, reduction_dim_size);
92 const bool non_multiple_of_n0 = ((reduction_dim_size % vector_size) != 0);
93
94 if(non_multiple_of_n0)
95 {
96 code += R"_(
97 VEC_TYPE data = VLOAD(N0)(0, (__global {{DATA_TYPE}} *)src_addr);
98 SELECT_TYPE widx = (SELECT_TYPE)PARTIAL_N0 > VEC_OFFS(SELECT_DATA_TYPE({{DATA_TYPE}}), N0);
99 max_val_vec = max(max_val_vec, select((VEC_TYPE)({{MINVAL}}), data, widx));
100)_";
101 }
102
103 code += R"_(
104 for(uint i = PARTIAL_N0; i < {{SRC_WIDTH}}; i += N0)
105 {
106 VEC_TYPE data = VLOAD(N0)(0, (__global {{DATA_TYPE}} *)(src_addr + i * sizeof({{DATA_TYPE}})));
107 max_val_vec = max(data, max_val_vec);
108 }
109
110 {{DATA_TYPE}} max_val = MAX_REDUCE(max_val_vec, N0);
111 VEC_TYPE sum1D = 0;
112)_";
113
114 if(non_multiple_of_n0)
115 {
116 code += R"_(
117 data -= max_val;
118)_";
119 if(beta_defined)
120 {
121 code += R"_(
122 data *= beta;
123)_";
124 }
125
126 if(_attributes.is_log_softmax())
127 {
128 code += R"_(
129 VSTORE_PARTIAL(N0, PARTIAL_N0)
130 (data, 0, (__global {{DATA_TYPE}} *)dst_addr);
131 data = exp(data);
132 data = select(0, data, widx);
133)_";
134 }
135 else
136 {
137 code += R"_(
138 data = exp(data);
139 data = select(0, data, widx);
140 VSTORE_PARTIAL(N0, PARTIAL_N0)
141 (data, 0, (__global {{DATA_TYPE}} *)dst_addr);
142)_";
143 }
144
145 code += R"_(
146 sum1D += data;
147)_";
148 }
Ramy Elgammal002e6532023-01-11 18:48:04 +0000149 code += R"_(
Gunes Bayiraecb5d92022-12-18 21:31:29 +0000150 for(uint i = PARTIAL_N0; i < {{SRC_WIDTH}}; i += N0)
151 {
152 VEC_TYPE data = VLOAD(N0)(0, (__global {{DATA_TYPE}} *)(src_addr + i * sizeof({{DATA_TYPE}})));
153 data -= max_val;
154)_";
155
Ramy Elgammal002e6532023-01-11 18:48:04 +0000156 if(beta_defined)
157 {
Gunes Bayiraecb5d92022-12-18 21:31:29 +0000158 code += R"_(
Ramy Elgammal002e6532023-01-11 18:48:04 +0000159 data *= beta;
Gunes Bayiraecb5d92022-12-18 21:31:29 +0000160)_";
161 }
162
Ramy Elgammal002e6532023-01-11 18:48:04 +0000163 if(_attributes.is_log_softmax())
164 {
165 code += R"_(
166 VSTORE(N0)
167 (data, 0, (__global {{DATA_TYPE}} *)(dst_addr + i * sizeof({{DATA_TYPE}})));
168 data = exp(data);
169)_";
170 }
171 else
172 {
173 code += R"_(
174 data = exp(data);
175 VSTORE(N0)
176 (data, 0, (__global {{DATA_TYPE}} *)(dst_addr + i * sizeof({{DATA_TYPE}})));
177)_";
178 }
179
180 code += R"_(
181 sum1D += data;
182 }
183)_";
184
Gunes Bayiraecb5d92022-12-18 21:31:29 +0000185 code += R"_(
186 *((__global {{DATA_TYPE}} *)sum.ptr) = SUM_REDUCE(sum1D, N0);
187}
188//------------------ END KERNEL {{meta_kernel_id}} ---------------------
189)_";
190
191 return code;
192}
193
194void ClTemplateLogits1DMaxShiftExpSum::declare_variables(GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
195{
196 vtable.declare_variable(
Viet-Hoa Do3558c582022-12-16 14:45:57 +0000197 comp_group,
Gunes Bayiraecb5d92022-12-18 21:31:29 +0000198 _src,
Ramy Elgammal002e6532023-01-11 18:48:04 +0000199 GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_3D),
Gunes Bayiraecb5d92022-12-18 21:31:29 +0000200 "src");
201
202 vtable.declare_variable(
Viet-Hoa Do3558c582022-12-16 14:45:57 +0000203 comp_group,
Gunes Bayiraecb5d92022-12-18 21:31:29 +0000204 _sum,
Ramy Elgammal002e6532023-01-11 18:48:04 +0000205 GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_3D),
Gunes Bayiraecb5d92022-12-18 21:31:29 +0000206 "sum");
207
208 vtable.declare_variable(
Viet-Hoa Do3558c582022-12-16 14:45:57 +0000209 comp_group,
Gunes Bayiraecb5d92022-12-18 21:31:29 +0000210 _dst,
Ramy Elgammal002e6532023-01-11 18:48:04 +0000211 GpuKernelArgumentInfo(GpuKernelArgumentInfo::Type::Tensor_3D),
Gunes Bayiraecb5d92022-12-18 21:31:29 +0000212 "dst");
213}
214
215TagLUT ClTemplateLogits1DMaxShiftExpSum::get_tag_lut(const GpuKernelVariableTable &vtable, const ComponentGroup &comp_group) const
216{
217 ARM_COMPUTE_UNUSED(comp_group);
218
219 TagLUT lut{};
220
221 // Arguments and global shared variables
222 lut["src"] = vtable.get_variable(_src);
223 lut["sum"] = vtable.get_variable(_sum);
224 lut["dst"] = vtable.get_variable(_dst);
225
226 // Local build options
227 lut["meta_kernel_id"] = id();
228
229 const DataType data_type = _src->data_type();
230
231 lut["DATA_TYPE"] = get_cl_type_from_data_type(data_type);
232 lut["BETA"] = float_to_string_with_full_precision(_attributes.beta());
233 lut["MINVAL"] = (data_type == DataType::F16) ? std::string("-HALF_MAX") : std::string("-FLT_MAX");
234 lut["SRC_WIDTH"] = support::cpp11::to_string(_src->dimension(0));
235
236 return lut;
237}
238
239CLBuildOptions ClTemplateLogits1DMaxShiftExpSum::get_build_options(const ComponentGroup &comp_group) const
240{
241 ARM_COMPUTE_UNUSED(comp_group);
242 CLBuildOptions build_opts{};
243
Gunes Bayiraecb5d92022-12-18 21:31:29 +0000244 const unsigned int reduction_dim_size = _src->dimension(0);
245 const unsigned int vector_size = adjust_vec_size(serial_vector_size, reduction_dim_size);
246
247 build_opts.add_option("-DN0=" + support::cpp11::to_string(vector_size));
248 build_opts.add_option("-DPARTIAL_N0=" + support::cpp11::to_string((reduction_dim_size % vector_size)));
249
250 return build_opts;
251}
252
253std::string ClTemplateLogits1DMaxShiftExpSum::get_config_id() const
254{
255 std::string config_id = get_name();
256
257 config_id += "_";
258 config_id += support::cpp11::to_string(_src->dimension(0));
259 config_id += "_";
260 config_id += string_from_data_type(_src->data_type());
261
262 return config_id;
263}
264
265std::set<std::string> ClTemplateLogits1DMaxShiftExpSum::get_headers_list() const
266{
Ramy Elgammal002e6532023-01-11 18:48:04 +0000267 return std::set<std::string>{ "helpers.h", "tile_helpers.h" };
Gunes Bayiraecb5d92022-12-18 21:31:29 +0000268}
269
270Window ClTemplateLogits1DMaxShiftExpSum::get_window() const
271{
272 ARM_COMPUTE_ERROR_ON_MSG(_dst->tensor_shape().total_size() == 0U, "Destination tensor is not initialized");
273
274 Window win = calculate_max_window(*_dst, Steps(_src->dimension(0)));
275 return win.collapse(win, Window::DimZ);
276}
277
278} // namespace dynamic_fusion
279} // namespace experimental
Viet-Hoa Do3558c582022-12-16 14:45:57 +0000280} // namespace arm_compute