blob: 283022e8e2b19dd43a0e2bd5a1796aff7c2badbc [file] [log] [blame]
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +00001/*
2 * Copyright (c) 2024 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
#include "ScatterLayer.h"

#include "arm_compute/core/TensorShape.h"

#include "tests/validation/Helpers.h"

#include <algorithm>
#include <vector>
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000027
28namespace arm_compute
29{
30namespace test
31{
32namespace validation
33{
34namespace reference
35{
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000036namespace
37{
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000038
39template <typename T>
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000040T reduce_op(const T &current,const T &update,const ScatterFunction func)
41{
42 switch(func)
43 {
44 case ScatterFunction::Update:
45 return update;
46 break;
47 case ScatterFunction::Add:
48 return current + update;
49 break;
50 case ScatterFunction::Sub:
51 return current - update;
52 break;
53 case ScatterFunction::Max:
54 return std::max(current, update);
55 break;
56 case ScatterFunction::Min:
57 return std::min(current, update);
58 break;
59 default:
60 ARM_COMPUTE_ERROR("Unsupported Scatter function");
61 break;
62 }
63}
64
65template float reduce_op(const float &current,const float &update,const ScatterFunction func);
66}
67
Mohammed Suhail Munshi0e212362024-04-08 14:38:31 +010068// NOTE: This function expects collapsed tensors as input.
69// Batch dims for update/indices tensors should be collapsed into a single dim.
70// Data dims should be collapsed into a single dim for both update and src tensors prior to calling this function.
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000071template <typename T>
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000072SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info)
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000073{
Mohammed Suhail Munshi0e212362024-04-08 14:38:31 +010074 // 1. If zero initialization variable is false, copy src data to dst.
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000075 SimpleTensor<T> dst{ out_shape, src.data_type(), 1 };
Mohammed Suhail Munshi0e212362024-04-08 14:38:31 +010076 if(!info.zero_initialization)
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000077 {
78 std::copy_n(src.data(), src.num_elements(), dst.data());
79 }
80
Mohammed Suhail Munshi0e212362024-04-08 14:38:31 +010081 // Number of elements between each value of the dim being iterated through
82 const unsigned int data_stride = updates.shape().total_size_lower(updates.shape().num_dimensions() - 1);
83 const unsigned int no_output_dims = out_shape.num_dimensions();
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000084
Mohammed Suhail Munshi0e212362024-04-08 14:38:31 +010085 // Calculate output stride at given index for all output dims.
86 std::vector<unsigned int> out_stride_at_idx(no_output_dims);
87 for (unsigned int i = 0 ; i < no_output_dims; i++)
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000088 {
Mohammed Suhail Munshi0e212362024-04-08 14:38:31 +010089 out_stride_at_idx[i] = out_shape.total_size_lower(i);
90 }
91
92 const unsigned int indices_x_dim = static_cast<unsigned int>(indices.shape()[0]);
93 const unsigned int indices_y_dim = static_cast<unsigned int>(indices.shape()[1]);
94
95 // 2. Iterate over indices tensor y-dim and replace sections of dst tensor with relevant areas of update tensor.
96 for(unsigned int i = 0; i < indices_y_dim; i++)
97 {
98 // NOTE : Currently, indices.shape() == [X, Y, 1, 1], where X is the indices dim and Y is the batch dim
99 // Starting index for both the update and indices tensors.
100 const unsigned int update_dim_start = i * data_stride;
101 const unsigned int indices_dim_start = i * indices_x_dim;
102 bool out_of_bounds = false;
103 unsigned int out_offset_acc = 0;
104
105 // Iterate over each indices value for the relevant batch and accumulate the offset.
106 for(unsigned int j = 0; j < indices_x_dim; j++)
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000107 {
Mohammed Suhail Munshi0e212362024-04-08 14:38:31 +0100108 // Get first index value with i * indices_x_dim (iterating through y-dim/batch idx), then iterate through x dim by adding k
109 const int index_value = indices[indices_dim_start + j];
110 const unsigned int out_dim = no_output_dims - (j+1); // Calculate corresponding output dim to current index value.
111 if(index_value < static_cast<int>(out_shape[out_dim]) && index_value >= 0)
112 {
113 out_offset_acc += (index_value * out_stride_at_idx[out_dim]); // offset accumulation
114 }
115 else
116 {
117 out_of_bounds = true;
118 break;
119 }
120 }
121
122 // If not out of bounds, copy update tensor elements to output
123 if(!out_of_bounds)
124 {
125 for (unsigned int j = 0 ; j < data_stride; j++)
126 {
127 dst[out_offset_acc + j] = reduce_op(dst[out_offset_acc + j], updates[update_dim_start + j], info.func);
128 }
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000129 }
130 }
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000131 return dst;
132}
133
134template <typename T>
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000135SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info)
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000136{
137 return scatter_layer_internal<T>(src, updates, indices, out_shape, info);
138}
139
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000140template SimpleTensor<float> scatter_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info);
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000141
142} // namespace reference
143} // namespace validation
144} // namespace test
145} // namespace arm_compute