blob: 7543b46bb12c354852e7cc00ff8923850ff223f8 [file] [log] [blame]
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +00001/*
2 * Copyright (c) 2024 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "ScatterLayer.h"
25#include "tests/validation/Helpers.h"
26
27namespace arm_compute
28{
29namespace test
30{
31namespace validation
32{
33namespace reference
34{
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000035namespace
36{
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000037
38template <typename T>
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000039T reduce_op(const T &current,const T &update,const ScatterFunction func)
40{
41 switch(func)
42 {
43 case ScatterFunction::Update:
44 return update;
45 break;
46 case ScatterFunction::Add:
47 return current + update;
48 break;
49 case ScatterFunction::Sub:
50 return current - update;
51 break;
52 case ScatterFunction::Max:
53 return std::max(current, update);
54 break;
55 case ScatterFunction::Min:
56 return std::min(current, update);
57 break;
58 default:
59 ARM_COMPUTE_ERROR("Unsupported Scatter function");
60 break;
61 }
62}
63
64template float reduce_op(const float &current,const float &update,const ScatterFunction func);
65}
66
67// Note : This function currently only supports 1D src, 1D updates, 2D indices, 1D output tensors.
68template <typename T>
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000069SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info)
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000070{
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000071 SimpleTensor<T> dst{ out_shape, src.data_type(), 1 };
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000072
73 // 1. If zero initialization variable is true, fill dst with 0 values. Else copy src data to dst.
74 if(info.zero_initialization)
75 {
76 for (int i = 0; i < src.num_elements(); ++i)
77 {
78 dst[i] = static_cast<T>(0);
79 }
80 }
81 else
82 {
83 std::copy_n(src.data(), src.num_elements(), dst.data());
84 }
85
86 // 2. Get max index of output tensor, then iterate over index tensor.
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000087 const int x_bound = static_cast<int>(dst.shape().x());
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000088
89
90 for(int i = 0; i < indices.num_elements(); ++i)
91 {
92 // 3. Check whether index is out of bounds for dst, if not then apply reduce op.
93 const auto index = indices[i];
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000094 if (index < x_bound && index >= 0) // Note : we ignore negative index values.
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000095 {
96 dst[index] = reduce_op(dst[index], updates[i], info.func);
97 }
98 }
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000099 return dst;
100}
101
102template <typename T>
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000103SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info)
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000104{
105 return scatter_layer_internal<T>(src, updates, indices, out_shape, info);
106}
107
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000108template SimpleTensor<float> scatter_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info);
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000109
110} // namespace reference
111} // namespace validation
112} // namespace test
113} // namespace arm_compute