blob: 9c25b63c72fb73d760b739f2dc566342a32ac9f3 [file] [log] [blame]
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +00001/*
2 * Copyright (c) 2024 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "src/gpu/cl/kernels/ClScatterKernel.h"
25
26#include "arm_compute/core/CL/ICLTensor.h"
27#include "arm_compute/core/ITensorPack.h"
28#include "arm_compute/core/TensorInfo.h"
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000029#include "arm_compute/core/Utils.h"
Gunes Bayirada32002024-04-24 10:27:13 +010030#include "arm_compute/core/utils/helpers/AdjustVecSize.h"
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000031
32#include "src/common/utils/Log.h"
33#include "src/core/helpers/WindowHelpers.h"
34#include "support/Cast.h"
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000035
Gunes Bayirada32002024-04-24 10:27:13 +010036#include <cstdint>
37
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000038namespace arm_compute
39{
40namespace opencl
41{
42namespace kernels
43{
Gunes Bayirada32002024-04-24 10:27:13 +010044
namespace
{
/** Largest index tuple length (INDEX_LENGTH) accepted by validate(); configure() also emits
 *  exactly this many -DOUT_SHAPE_N_MINUS_<i> build options for the OpenCL program. */
constexpr int max_index_length{5};
} // namespace
49
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000050ClScatterKernel::ClScatterKernel()
51{
52}
53
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000054Status ClScatterKernel::validate(const ITensorInfo *updates,
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000055 const ITensorInfo *indices,
56 const ITensorInfo *dst,
57 const ScatterInfo &info)
58{
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000059 ARM_COMPUTE_UNUSED(info);
60
Gunes Bayirada32002024-04-24 10:27:13 +010061 const TensorShape &ind_shape = indices->tensor_shape();
62 const TensorShape &upt_shape = updates->tensor_shape();
63 const TensorShape &dst_shape = dst->tensor_shape();
64
65 const int32_t upt_dims = upt_shape.num_dimensions();
66 const int32_t dst_dims = dst_shape.num_dimensions();
67 const int32_t ind_dims = ind_shape.num_dimensions();
68
69 const int32_t index_len = ind_shape[0];
70
71 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(updates, dst);
72 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(indices, DataType::S32);
73 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32);
74 ARM_COMPUTE_RETURN_ERROR_ON_MSG(ind_dims > 2, "Only 2D indices tensors are currently supported.");
75 ARM_COMPUTE_RETURN_ERROR_ON_MSG(
76 ind_shape[1] != upt_shape[upt_dims - 1],
77 "Height of indices tensor should match size of highest dimension in updates tensor.");
78 ARM_COMPUTE_RETURN_ERROR_ON_MSG(upt_dims > dst_dims, "Update tensor cannot have more dims than output tensor.");
79
80 ARM_COMPUTE_RETURN_ERROR_ON_MSG(index_len > max_index_length, "Maximum supported index length is 5!");
81 ARM_COMPUTE_RETURN_ERROR_ON(index_len != dst_dims - upt_dims + 1);
82
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000083 return Status{};
84}
Gunes Bayirada32002024-04-24 10:27:13 +010085
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000086void ClScatterKernel::configure(const ClCompileContext &compile_context,
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000087 const ITensorInfo *updates,
88 const ITensorInfo *indices,
89 ITensorInfo *dst,
90 const ScatterInfo &info)
91{
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000092 ARM_COMPUTE_ERROR_ON_NULLPTR(updates, dst, indices);
93 ARM_COMPUTE_LOG_PARAMS(updates, indices, dst, info);
94
Gunes Bayirada32002024-04-24 10:27:13 +010095 const TensorShape &dst_shape = dst->tensor_shape();
96
97 const bool is_scalar_block = updates->num_dimensions() == 1;
98 const int n0 = adjust_vec_size(16 / updates->element_size(), is_scalar_block ? 1 : updates->dimension(0));
99
100 const int partial_n0 = updates->dimension(0) % n0;
101
102 // The GWS will be 2D [x, y]
103 // x-dimension refers to the x coordinate of the dst tensor
104 // y-dimension refers to the collapsed y-coordinate of the data part of the dst tensor
105 Window win = calculate_max_window(dst_shape, Steps(n0));
106 const int index_len = indices->dimension(0);
107
108 // Collapse the dimensions corresponding to indices in the execution window
109 for (int i = 0; i < index_len; ++i)
110 {
111 win.set(dst->num_dimensions() - (i + 1), Window::Dimension(0, 1, 1));
112 }
113
114 win = win.collapse(win, 1);
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000115
116 // Set build options
117 CLBuildOptions build_opts;
118 build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dst->data_type()));
Gunes Bayirada32002024-04-24 10:27:13 +0100119
120 const int num_dims = dst->num_dimensions();
121
122 build_opts.add_option("-DNUM_INDICES=" + support::cpp11::to_string(indices->dimension(1)));
123 build_opts.add_option("-DINDEX_LENGTH=" + support::cpp11::to_string(index_len));
124
125 // We provide 5 variables to use in a constant array
126 for (int i = 1; i <= max_index_length; i++)
127 {
128 build_opts.add_option("-DOUT_SHAPE_N_MINUS_" + support::cpp11::to_string(i) + "=" +
129 support::cpp11::to_string(dst_shape[std::max(num_dims - i, 0)]));
130 }
131
132 build_opts.add_option("-DN0=" + support::cpp11::to_string(n0));
133 build_opts.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_n0));
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000134
135 switch (info.func)
136 {
137 case ScatterFunction::Update:
138 build_opts.add_option("-DSCATTER_FUNCTION=UPDATE_OP");
Gunes Bayirada32002024-04-24 10:27:13 +0100139 build_opts.add_option("-DSKIP_OUTPUT_READ");
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000140 break;
141 case ScatterFunction::Add:
142 build_opts.add_option("-DSCATTER_FUNCTION=ADD_OP");
143 break;
144 case ScatterFunction::Sub:
145 build_opts.add_option("-DSCATTER_FUNCTION=SUB_OP");
146 break;
147 case ScatterFunction::Max:
148 build_opts.add_option("-DSCATTER_FUNCTION=MAX_OP");
149 break;
150 case ScatterFunction::Min:
151 build_opts.add_option("-DSCATTER_FUNCTION=MIN_OP");
152 break;
153 default:
154 ARM_COMPUTE_ERROR("Not implemented");
155 }
156
157 // Create kernel
Gunes Bayirada32002024-04-24 10:27:13 +0100158 std::string kernel_name = "scatter_mp1d_2d_mpnd";
159 build_opts.add_option("-D" + upper_string(kernel_name));
160
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000161 ICLKernel::configure_internal(win);
162 _kernel = create_kernel(compile_context, kernel_name, build_opts.options());
Gunes Bayirada32002024-04-24 10:27:13 +0100163
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000164 // Set config_id for enabling LWS tuning
165 _config_id = kernel_name;
166 _config_id += "_";
167 _config_id += lower_string(string_from_data_type(updates->data_type()));
168 _config_id += "_";
169 _config_id += support::cpp11::to_string(dst->dimension(1));
170 _config_id += "_";
171 _config_id += support::cpp11::to_string(dst->dimension(0));
172 _config_id += "_";
173 _config_id += support::cpp11::to_string(dst->dimension(2));
174 _config_id += "_";
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000175}
176
177void ClScatterKernel::run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue)
178{
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000179 const auto updates =
180 utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_0));
181 const auto indices =
182 utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1));
183 auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST));
Gunes Bayirada32002024-04-24 10:27:13 +0100184
185 const ITensorInfo *dst_info = dst->info();
186 const int num_dims = dst_info->num_dimensions();
187
188 const int index_len = indices->info()->dimension(0);
189
190 // calculate m-dimensional data block strides in updates and destination tensors
191 const int upt_block_stride = updates->info()->strides_in_bytes()[updates->info()->num_dimensions() - 1];
192 const int out_block_stride = dst_info->strides_in_bytes()[num_dims - index_len];
193
194 unsigned int idx = 0;
195
196 add_2D_tensor_argument(idx, updates, window);
197 add_2D_tensor_argument(idx, indices, window);
198 add_2D_tensor_argument(idx, dst, window);
199
200 _kernel.setArg<cl_int>(idx++, upt_block_stride);
201 _kernel.setArg<cl_int>(idx++, out_block_stride);
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000202
203 enqueue(queue, *this, window, lws_hint());
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000204}
205
206} // namespace kernels
207} // namespace opencl
208} // namespace arm_compute