blob: 1c90d208bdbf2902637e6e1a005d027917b32166 [file] [log] [blame]
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +00001/*
2 * Copyright (c) 2024 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#ifndef ACL_ARM_COMPUTE_RUNTIME_CL_FUNCTIONS_CLSCATTER_H
26#define ACL_ARM_COMPUTE_RUNTIME_CL_FUNCTIONS_CLSCATTER_H
27
28#include "arm_compute/core/Error.h"
29#include "arm_compute/runtime/IFunction.h"
30
31#include <memory>
32
33namespace arm_compute
34{
// Forward declarations. CLScatter only refers to these types through pointers and
// references, so their full definitions are not required by this header.
class CLCompileContext;
class ICLTensor;
class ITensorInfo;
struct ScatterInfo;
39
40/** Function to compute ScatterND Layer */
41class CLScatter : public IFunction
42{
43public:
44 /** Default Constructor */
45 CLScatter();
46 /** Prevent instances of this class from being copied (As this class contains pointers) */
47 CLScatter(const CLScatter &) = delete;
48 /** Default move constructor */
49 CLScatter(CLScatter &&);
50 /** Prevent instances of this class from being copied (As this class contains pointers) */
51 CLScatter &operator=(const CLScatter &) = delete;
52 /** Default move assignment operator */
53 CLScatter &operator=(CLScatter &&);
54 /** Default destructor */
55 ~CLScatter();
56 /** Initialise the kernel's inputs and outputs
57 *
58 * Valid data layouts:
59 * - All
60 *
61 *
62 * @param[in] compile_context The compile context to be used.
63 * @param[in] src Source tensor. Values used to fill output. Can be nullptr when zero initialization is true.
64 * @param[in] updates Tensor containing values used to update output tensor. Data types supported: same as @p src
65 * @param[in] indices Tensor containing Indices to change in the output Tensor. Data types supported : U32
66 * @param[out] output Destination tensor. Data types supported: same as @p src.
67 * @param[in] info Scatter info object.
68 */
69 void configure(const CLCompileContext &compile_context,
70 const ICLTensor *src,
71 const ICLTensor *updates,
72 const ICLTensor *indices,
73 ICLTensor *output,
74 const ScatterInfo &info);
75 /** Initialise the kernel's inputs and output
76 *
77 * Similar to @ref CLScatter::configure()
78 */
79 void configure(const ICLTensor *src,
80 const ICLTensor *updates,
81 const ICLTensor *indices,
82 ICLTensor *output,
83 const ScatterInfo &info);
84 /** Static function to check if given info will lead to a valid configuration of @ref CLScatter
85 *
86 * @param[in] src Source tensor.
87 * @param[in] updates Tensor containing values used for updating the output Tensor. Data types supported : same as @p src
88 * @param[in] indices Tensor containing Indices to change in the output Tensor. Data types supported : U32
89 * @param[in] output Destination tensor. Data types supported: same as @p src.
90 * @param[in] info Scatter info containing type of scatter.
91 *
92 * @return a status
93 */
94 static Status validate(const ITensorInfo *src,
95 const ITensorInfo *updates,
96 const ITensorInfo *indices,
97 const ITensorInfo *output,
98 const ScatterInfo &info);
99
100 // Inherited methods overridden:
101 void run() override;
102
103private:
104 struct Impl;
105 std::unique_ptr<Impl> _impl;
106};
107} // namespace arm_compute
108
109#endif // ACL_ARM_COMPUTE_RUNTIME_CL_FUNCTIONS_CLSCATTER_H