/*
 * Copyright (c) 2024 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

25#ifndef ACL_SRC_GPU_CL_OPERATORS_CLSCATTER_H
26#define ACL_SRC_GPU_CL_OPERATORS_CLSCATTER_H
27
28#include "arm_compute/function_info/ScatterInfo.h"
29
30#include "src/gpu/cl/IClKernel.h"
31#include "src/gpu/cl/IClOperator.h"
32
33#include <memory>
34
35namespace arm_compute
36{
37namespace opencl
38{
39// Forward declaration
40class ClFillKernel;
41class ClScatterKernel;
42
43/** Basic operator to execute Scatter on OpenCL. This operator calls the following OpenCL kernels:
44 *
45 * -# @ref kernels::ClScatterKernel
46 */
47class ClScatter : public IClOperator
48{
49public:
50 /** Constructor */
51 ClScatter();
52 /** Default destructor */
53 ~ClScatter() = default;
54 /** Initialise the kernel's inputs and output
55 *
56 * Valid data layouts:
57 * - All
58 *
59 * @note indices must always be U32
60 * @note src, updates and dst tensors must be same datatype.
61 *
62 * @param[in] compile_context The compile context to be used.
63 * @param[in] src Source input tensor info. Can be nullptr when using "Add" Scatter Function with zero initialization.
64 * @param[in] updates Tensor info for tensor storing update values to use for scatter function. Data types supported: same as @p src.
65 * @param[in] indices Tensor info for tensor storing indices to use for scatter function. Data types supported: U32 only.
66 * @param[out] dst Output tensor to store the result of the Scatter Function. Data types supported: same as @p src and @p updates.
67 * @param[in] Scatter_info Contains Scatter operation information described in @ref ScatterInfo.
68 */
69 void configure(const CLCompileContext &compile_context,
70 const ITensorInfo *src,
71 const ITensorInfo *updates,
72 const ITensorInfo *indices,
73 ITensorInfo *dst,
74 const ScatterInfo &Scatter_info);
75 /** Static function to check if given info will lead to a valid configuration
76 *
77 * Similar to @ref ClScatter::configure()
78 *
79 * @return a status
80 */
81 static Status validate(const ITensorInfo *src,
82 const ITensorInfo *updates,
83 const ITensorInfo *indices,
84 const ITensorInfo *dst,
85 const ScatterInfo &Scatter_info);
86 // Inherited methods overridden:
87 void run(ITensorPack &tensors) override;
88
89private:
90 std::unique_ptr<opencl::IClKernel> _scatter_kernel{nullptr};
91 std::unique_ptr<opencl::IClKernel> _fill_kernel{nullptr};
92 bool _fill_zero{false};
93};
94} // namespace opencl
95} // namespace arm_compute
96#endif // ACL_SRC_GPU_CL_OPERATORS_CLSCATTER_H