blob: 91e28b58f7e2605306e42c3ea302c7a5ed34d059 [file] [log] [blame]
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +00001/*
2 * Copyright (c) 2024 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef ACL_TESTS_VALIDATION_FIXTURES_SCATTERLAYERFIXTURE_H
25#define ACL_TESTS_VALIDATION_FIXTURES_SCATTERLAYERFIXTURE_H
26
27#include "arm_compute/core/Utils.h"
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000028#include "arm_compute/runtime/CL/CLTensorAllocator.h"
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000029#include "tests/Globals.h"
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000030#include "tests/framework/Asserts.h"
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000031#include "tests/framework/Fixture.h"
32#include "tests/validation/Validation.h"
33#include "tests/validation/reference/ScatterLayer.h"
34#include "tests/SimpleTensor.h"
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000035
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000036#include <random>
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000037#include <cstdint>
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000038
39namespace arm_compute
40{
41namespace test
42{
43namespace validation
44{
45template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
46class ScatterGenericValidationFixture : public framework::Fixture
47{
48public:
49 void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
50 {
51 _target = compute_target(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, src_qinfo, o_qinfo);
52 _reference = compute_reference(src_shape, updates_shape, indices_shape, out_shape, data_type,scatter_info, src_qinfo , o_qinfo);
53 }
54
55protected:
56 template <typename U>
Mohammed Suhail Munshi0e212362024-04-08 14:38:31 +010057 void fill(U &&tensor, int i, float lo = -10.f, float hi = 10.f)
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000058 {
59 switch(tensor.data_type())
60 {
61 case DataType::F32:
62 {
63 std::uniform_real_distribution<float> distribution(lo, hi);
64 library->fill(tensor, distribution, i);
65 break;
66 }
67 default:
68 {
69 ARM_COMPUTE_ERROR("Unsupported data type.");
70 }
71 }
72 }
73
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000074 // This is used to fill indices tensor with S32 datatype.
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000075 // Used to prevent ONLY having values that are out of bounds.
76 template <typename U>
77 void fill_indices(U &&tensor, int i, const TensorShape &shape)
78 {
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000079 // Calculate max indices the shape should contain. Add an arbitrary value to allow testing for some out of bounds values (In this case min dimension)
80 const int32_t max = std::max({shape[0] , shape[1], shape[2]});
81 library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(-2), static_cast<int32_t>(max));
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000082 }
83
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000084 TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &out_shape, DataType data_type, const ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
85 {
86 // 1. Create relevant tensors using ScatterInfo data structure.
87 // ----------------------------------------------------
88 // In order - src, updates, indices, output.
89 TensorType src = create_tensor<TensorType>(shape_a, data_type, 1, a_qinfo);
90 TensorType updates = create_tensor<TensorType>(shape_b, data_type, 1, a_qinfo);
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000091 TensorType indices = create_tensor<TensorType>(shape_c, DataType::S32, 1, QuantizationInfo());
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000092 TensorType dst = create_tensor<TensorType>(out_shape, data_type, 1, o_qinfo);
93
94 FunctionType scatter;
95
96 // Configure operator
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000097 // When scatter_info.zero_initialization is true, pass nullptr to scatter function.
98 if(info.zero_initialization)
99 {
100 scatter.configure(nullptr, &updates, &indices, &dst, info);
101 }
102 else
103 {
104 scatter.configure(&src, &updates, &indices, &dst, info);
105 }
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000106
107 // Assertions
108 ARM_COMPUTE_ASSERT(src.info()->is_resizable());
109 ARM_COMPUTE_ASSERT(updates.info()->is_resizable());
110 ARM_COMPUTE_ASSERT(indices.info()->is_resizable());
111 ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
112
113 // Allocate tensors
114 src.allocator()->allocate();
115 updates.allocator()->allocate();
116 indices.allocator()->allocate();
117 dst.allocator()->allocate();
118
119 ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
120 ARM_COMPUTE_ASSERT(!updates.info()->is_resizable());
121 ARM_COMPUTE_ASSERT(!indices.info()->is_resizable());
122 ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
123
124 // Fill update (a) and indices (b) tensors.
125 fill(AccessorType(src), 0);
126 fill(AccessorType(updates), 1);
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000127 fill_indices(AccessorType(indices), 2, out_shape);
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000128
129 scatter.run();
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000130 return dst;
131 }
132
133 SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape, const TensorShape &out_shape, DataType data_type,
134 ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
135 {
136 // Output Quantization not currently in use - fixture should be extended to support this.
137 ARM_COMPUTE_UNUSED(o_qinfo);
Mohammed Suhail Munshi0e212362024-04-08 14:38:31 +0100138 TensorShape src_shape = a_shape;
139 TensorShape updates_shape = b_shape;
140 TensorShape indices_shape = c_shape;
141
142 // 1. Collapse batch index into a single dim if necessary for update tensor and indices tensor.
143 if(c_shape.num_dimensions() >= 3)
144 {
145 indices_shape = indices_shape.collapsed_from(1);
146 updates_shape = updates_shape.collapsed_from(updates_shape.num_dimensions() - 2); // Collapses from last 2 dims
147 }
148
149 // 2. Collapse data dims into a single dim.
150 // Collapse all src dims into 2 dims. First one holding data, the other being the index we iterate over.
151 src_shape.collapse(updates_shape.num_dimensions() - 1); // Collapse all data dims into single dim.
152 src_shape = src_shape.collapsed_from(1); // Collapse all index dims into a single dim
153 updates_shape.collapse(updates_shape.num_dimensions() - 1); // Collapse data dims (all except last dim which is batch dim)
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000154
155 // Create reference tensors
156 SimpleTensor<T> src{ a_shape, data_type, 1, a_qinfo };
157 SimpleTensor<T> updates{b_shape, data_type, 1, QuantizationInfo() };
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000158 SimpleTensor<int32_t> indices{ c_shape, DataType::S32, 1, QuantizationInfo() };
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000159
160 // Fill reference
161 fill(src, 0);
162 fill(updates, 1);
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000163 fill_indices(indices, 2, out_shape);
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000164
165 // Calculate individual reference.
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000166 return reference::scatter_layer<T>(src, updates, indices, out_shape, info);
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000167 }
168
169 TensorType _target{};
170 SimpleTensor<T> _reference{};
171};
172
173// This fixture will use the same shape for updates as indices.
174template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
175class ScatterValidationFixture : public ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>
176{
177public:
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000178 void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init)
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000179 {
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000180 ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape, indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), QuantizationInfo(), QuantizationInfo());
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000181 }
182};
183
184} // namespace validation
185} // namespace test
186} // namespace arm_compute
187#endif // ACL_TESTS_VALIDATION_FIXTURES_SCATTERLAYERFIXTURE_H