blob: 973953624ee9129aaffd04e3c5a6ea0d5ff386de [file] [log] [blame]
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +00001/*
2 * Copyright (c) 2024 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#ifndef ACL_ARM_COMPUTE_RUNTIME_CL_FUNCTIONS_CLSCATTER_H
26#define ACL_ARM_COMPUTE_RUNTIME_CL_FUNCTIONS_CLSCATTER_H
27
28#include "arm_compute/core/Error.h"
29#include "arm_compute/runtime/IFunction.h"
30
31#include <memory>
32
33namespace arm_compute
34{
// Forward declarations: the declarations below only take these types by
// pointer or by reference, so complete definitions are not needed here.
class CLCompileContext;
class ICLTensor;
class ITensorInfo;
struct ScatterInfo;
39
40/** Function to compute ScatterND Layer */
41class CLScatter : public IFunction
42{
43public:
44 /** Default Constructor */
45 CLScatter();
46 /** Prevent instances of this class from being copied (As this class contains pointers) */
47 CLScatter(const CLScatter &) = delete;
48 /** Default move constructor */
49 CLScatter(CLScatter &&);
50 /** Prevent instances of this class from being copied (As this class contains pointers) */
51 CLScatter &operator=(const CLScatter &) = delete;
52 /** Default move assignment operator */
53 CLScatter &operator=(CLScatter &&);
54 /** Default destructor */
55 ~CLScatter();
56 /** Initialise the kernel's inputs and outputs
57 *
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000058 * @note Negative indices are treated as out of bounds.
59 *
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000060 * Valid data layouts:
61 * - All
62 *
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000063 * @param[in] compile_context The compile context to be used.
64 * @param[in] src Source tensor. Values used to fill output. Can be nullptr when zero initialization is true.
65 * @param[in] updates Tensor containing values used to update output tensor. Data types supported: same as @p src
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000066 * @param[in] indices Tensor containing Indices to change in the output Tensor. Data types supported : S32
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000067 * @param[out] output Destination tensor. Data types supported: same as @p src.
68 * @param[in] info Scatter info object.
69 */
70 void configure(const CLCompileContext &compile_context,
71 const ICLTensor *src,
72 const ICLTensor *updates,
73 const ICLTensor *indices,
74 ICLTensor *output,
75 const ScatterInfo &info);
76 /** Initialise the kernel's inputs and output
77 *
78 * Similar to @ref CLScatter::configure()
79 */
80 void configure(const ICLTensor *src,
81 const ICLTensor *updates,
82 const ICLTensor *indices,
83 ICLTensor *output,
84 const ScatterInfo &info);
85 /** Static function to check if given info will lead to a valid configuration of @ref CLScatter
86 *
87 * @param[in] src Source tensor.
88 * @param[in] updates Tensor containing values used for updating the output Tensor. Data types supported : same as @p src
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000089 * @param[in] indices Tensor containing Indices to change in the output Tensor. Data types supported : S32
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000090 * @param[in] output Destination tensor. Data types supported: same as @p src.
91 * @param[in] info Scatter info containing type of scatter.
92 *
93 * @return a status
94 */
95 static Status validate(const ITensorInfo *src,
96 const ITensorInfo *updates,
97 const ITensorInfo *indices,
98 const ITensorInfo *output,
99 const ScatterInfo &info);
100
101 // Inherited methods overridden:
102 void run() override;
103
104private:
105 struct Impl;
106 std::unique_ptr<Impl> _impl;
107};
108} // namespace arm_compute
109
110#endif // ACL_ARM_COMPUTE_RUNTIME_CL_FUNCTIONS_CLSCATTER_H