blob: a11ecd7e6a8bed8b7173ba8df2ae636d99c82e0d [file] [log] [blame]
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +00001/*
2 * Copyright (c) 2024 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "src/gpu/cl/operators/ClScatter.h"
25
26#include "arm_compute/core/Error.h"
27#include "arm_compute/runtime/CL/CLScheduler.h"
28
29#include "src/common/utils/Log.h"
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000030#include "src/gpu/cl/kernels/ClCopyKernel.h"
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000031#include "src/gpu/cl/kernels/ClFillKernel.h"
32#include "src/gpu/cl/kernels/ClScatterKernel.h"
33
34namespace arm_compute
35{
36namespace opencl
37{
38using namespace arm_compute::opencl::kernels;
39
40ClScatter::ClScatter()
41{
42}
43
44Status ClScatter::validate(const ITensorInfo *src,
45 const ITensorInfo *updates,
46 const ITensorInfo *indices,
47 const ITensorInfo *dst,
48 const ScatterInfo &info)
49{
50 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(updates, indices, dst);
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000051 if (src != nullptr)
52 {
53 // Check dst/src are same shape and datatype.
54 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(src->tensor_shape(), dst->tensor_shape());
55 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, updates, dst);
56 ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClCopyKernel::validate(src, dst)); // Validate Copy kernel
57 }
58 if (src != dst)
59 {
60 ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClFillKernel::validate(dst, PixelValue(0.0f))); // Validate Fill kernel.
61 }
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000062
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000063 return kernels::ClScatterKernel::validate(updates, indices, dst, info);
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000064}
65
66void ClScatter::configure(const CLCompileContext &compile_context,
67 const ITensorInfo *src,
68 const ITensorInfo *updates,
69 const ITensorInfo *indices,
70 ITensorInfo *dst,
71 const ScatterInfo &info)
72{
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000073 ARM_COMPUTE_ERROR_ON_NULLPTR(updates, indices, dst);
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000074 ARM_COMPUTE_LOG_PARAMS(src, indices, dst, info);
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000075
76 // Perform validation step
77 ARM_COMPUTE_ERROR_THROW_ON(validate(src, updates, indices, dst, info));
78 _fill_zero = info.zero_initialization;
79
80 // If necessary, create fill kernel to fill dst tensor.
81 if (_fill_zero)
82 {
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000083 auto f = std::make_unique<kernels::ClFillKernel>();
84 f->configure(compile_context, dst, PixelValue(0.0f));
85 _fill_kernel = std::move(f);
86 }
87 else if (src != dst) // Check whether copying is necessary
88 {
89 // Fill dst with src copy here.
90 auto j = std::make_unique<kernels::ClCopyKernel>();
91 j->configure(compile_context, src, dst);
92 _copy_kernel = std::move(j);
93 _run_copy = true;
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000094 }
95
96 // Configure ClScatterKernel
97 auto k = std::make_unique<kernels::ClScatterKernel>();
98 k->set_target(CLScheduler::get().target());
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000099 k->configure(compile_context, updates, indices, dst, info);
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000100 _scatter_kernel = std::move(k);
101}
102
103void ClScatter::run(ITensorPack &tensors)
104{
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000105 // Get tensors.
106 auto src = tensors.get_const_tensor(ACL_SRC_0);
107 auto updates = tensors.get_const_tensor(ACL_SRC_1);
108 auto indices = tensors.get_const_tensor(ACL_SRC_2);
109 auto dst = tensors.get_tensor(ACL_DST);
110
111 if (_fill_zero)
112 {
113 // Fill destination tensor with 0 values if zero init.
114 ITensorPack fill_pack{{ACL_SRC, dst}};
115 CLScheduler::get().enqueue_op(*_fill_kernel, fill_pack, false);
116 }
117
118 if (_run_copy)
119 {
120 // copy src to dst before scatter op.
121 ITensorPack copy_pack{{ACL_SRC, src}, {ACL_DST, dst}};
122 CLScheduler::get().enqueue_op(*_copy_kernel, copy_pack, false);
123 }
124
125 ITensorPack scatter_pack{{ACL_SRC_0, updates}, {ACL_SRC_1, indices}, {ACL_DST, dst}};
126 CLScheduler::get().enqueue_op(*_scatter_kernel, scatter_pack, false);
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000127}
128
129} // namespace opencl
130} // namespace arm_compute