blob: 2970d82572d9374e0d755b262068686dad1ad0de [file] [log] [blame]
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +00001/*
2 * Copyright (c) 2024 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/CL/CLTensor.h"
25#include "arm_compute/runtime/CL/functions/CLScatter.h"
26#include "tests/validation/fixtures/ScatterLayerFixture.h"
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000027#include "tests/datasets/ScatterDataset.h"
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000028#include "tests/CL/CLAccessor.h"
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000029#include "arm_compute/function_info/ScatterInfo.h"
30#include "tests/framework/Asserts.h"
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000031#include "tests/framework/Macros.h"
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000032#include "tests/framework/datasets/Datasets.h"
33#include "tests/validation/Validation.h"
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000034
35namespace arm_compute
36{
37namespace test
38{
39namespace validation
40{
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000041namespace
42{
43RelativeTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for fp32 data type */
Gunes Bayir301e33f2024-04-29 17:00:14 +010044RelativeTolerance<float> tolerance_f16(0.02f); /**< Tolerance value for comparing reference's output against implementation's output for fp16 data type */
45RelativeTolerance<int32_t> tolerance_int(0); /**< Tolerance value for comparing reference's output against implementation's output for integer data types */
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000046} // namespace
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000047
48template <typename T>
49using CLScatterLayerFixture = ScatterValidationFixture<CLTensor, CLAccessor, CLScatter, T>;
50
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000051using framework::dataset::make;
52
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000053TEST_SUITE(CL)
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000054TEST_SUITE(Scatter)
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000055DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip(
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000056 make("InputInfo", { TensorInfo(TensorShape(9U), 1, DataType::F32), // Mismatching data types
Gunes Bayirada32002024-04-24 10:27:13 +010057 TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid
Gunes Bayir301e33f2024-04-29 17:00:14 +010058 TensorInfo(TensorShape(15U), 1, DataType::U8), // Valid
Gunes Bayirada32002024-04-24 10:27:13 +010059 TensorInfo(TensorShape(8U), 1, DataType::F32),
60 TensorInfo(TensorShape(217U), 1, DataType::F32), // Mismatch input/output dims.
61 TensorInfo(TensorShape(217U), 1, DataType::F32), // Updates dim higher than Input/Output dims.
62 TensorInfo(TensorShape(12U), 1, DataType::F32), // Indices wrong datatype.
63 TensorInfo(TensorShape(9U, 3U, 4U), 1, DataType::F32), // Number of updates != number of indices
64 TensorInfo(TensorShape(17U, 3U, 3U, 2U), 1, DataType::F32), // index_len != (dst_dims - upt_dims + 1)
65 TensorInfo(TensorShape(17U, 3U, 3U, 2U, 2U, 2U), 1, DataType::F32), // index_len > 5
66 }),
67 make("UpdatesInfo",{TensorInfo(TensorShape(3U), 1, DataType::F16),
68 TensorInfo(TensorShape(15U), 1, DataType::F32),
Gunes Bayir301e33f2024-04-29 17:00:14 +010069 TensorInfo(TensorShape(15U), 1, DataType::U8),
Gunes Bayirada32002024-04-24 10:27:13 +010070 TensorInfo(TensorShape(2U), 1, DataType::F32),
71 TensorInfo(TensorShape(217U), 1, DataType::F32),
72 TensorInfo(TensorShape(217U, 3U), 1, DataType::F32),
73 TensorInfo(TensorShape(2U), 1, DataType::F32),
74 TensorInfo(TensorShape(9U, 3U, 2U), 1, DataType::F32),
75 TensorInfo(TensorShape(17U, 3U, 2U), 1, DataType::F32),
76 TensorInfo(TensorShape(1U), 1, DataType::F32),
77 }),
78 make("IndicesInfo",{TensorInfo(TensorShape(1U, 3U), 1, DataType::S32),
79 TensorInfo(TensorShape(1U, 15U), 1, DataType::S32),
Gunes Bayir301e33f2024-04-29 17:00:14 +010080 TensorInfo(TensorShape(1U, 15U), 1, DataType::S32),
Gunes Bayirada32002024-04-24 10:27:13 +010081 TensorInfo(TensorShape(1U, 2U), 1, DataType::S32),
82 TensorInfo(TensorShape(1U, 271U), 1, DataType::S32),
83 TensorInfo(TensorShape(1U, 271U), 1, DataType::S32),
84 TensorInfo(TensorShape(1U, 2U), 1 , DataType::F32),
85 TensorInfo(TensorShape(1U, 4U), 1, DataType::S32),
86 TensorInfo(TensorShape(3U, 2U), 1, DataType::S32),
87 TensorInfo(TensorShape(6U, 2U), 1, DataType::S32),
88 }),
89 make("OutputInfo",{TensorInfo(TensorShape(9U), 1, DataType::F16),
90 TensorInfo(TensorShape(15U), 1, DataType::F32),
Gunes Bayir301e33f2024-04-29 17:00:14 +010091 TensorInfo(TensorShape(15U), 1, DataType::U8),
Gunes Bayirada32002024-04-24 10:27:13 +010092 TensorInfo(TensorShape(8U), 1, DataType::F32),
93 TensorInfo(TensorShape(271U, 3U), 1, DataType::F32),
94 TensorInfo(TensorShape(271U), 1, DataType::F32),
95 TensorInfo(TensorShape(12U), 1, DataType::F32),
96 TensorInfo(TensorShape(9U, 3U, 4U), 1, DataType::F32),
97 TensorInfo(TensorShape(17U, 3U, 3U, 2U), 1, DataType::F32),
98 TensorInfo(TensorShape(17U, 3U, 3U, 2U, 2U, 2U), 1, DataType::F32),
99 }),
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000100 make("ScatterInfo",{ ScatterInfo(ScatterFunction::Add, false),
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000101 ScatterInfo(ScatterFunction::Max, false),
Gunes Bayir301e33f2024-04-29 17:00:14 +0100102 ScatterInfo(ScatterFunction::Max, false),
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000103 ScatterInfo(ScatterFunction::Min, false),
104 ScatterInfo(ScatterFunction::Add, false),
105 ScatterInfo(ScatterFunction::Update, false),
106 ScatterInfo(ScatterFunction::Sub, false),
Gunes Bayirada32002024-04-24 10:27:13 +0100107 ScatterInfo(ScatterFunction::Sub, false),
108 ScatterInfo(ScatterFunction::Update, false),
109 ScatterInfo(ScatterFunction::Update, false),
110 }),
Gunes Bayir301e33f2024-04-29 17:00:14 +0100111 make("Expected", { false, true, true, true, false, false, false, false, false, false })),
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000112 input_info, updates_info, indices_info, output_info, scatter_info, expected)
113{
Gunes Bayirada32002024-04-24 10:27:13 +0100114 const Status status = CLScatter::validate(&input_info, &updates_info, &indices_info, &output_info, scatter_info);
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000115 ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000116}
117
Gunes Bayirada32002024-04-24 10:27:13 +0100118const auto allScatterFunctions = make("ScatterFunction",
119 {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max });
120
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000121TEST_SUITE(Float)
122TEST_SUITE(FP32)
Gunes Bayirada32002024-04-24 10:27:13 +0100123FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
124 combine(datasets::Small1DScatterDataset(),
125 make("DataType", {DataType::F32}),
126 allScatterFunctions,
127 make("ZeroInit", {false}),
128 make("Inplace", {false})))
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000129{
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000130 validate(CLAccessor(_target), _reference, tolerance_f32);
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000131}
132
133// With this test, src should be passed as nullptr.
Gunes Bayirada32002024-04-24 10:27:13 +0100134FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
135 combine(datasets::Small1DScatterDataset(),
136 make("DataType", {DataType::F32}),
137 make("ScatterFunction", {ScatterFunction::Add}),
138 make("ZeroInit", {true}),
139 make("Inplace", {false})))
140{
141 validate(CLAccessor(_target), _reference, tolerance_f32);
142}
143
144// Updates/src/dst have same no. dims.
145FIXTURE_DATA_TEST_CASE(RunSmallMultiDim, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
146 combine(datasets::SmallScatterMultiDimDataset(),
147 make("DataType", {DataType::F32}),
148 allScatterFunctions,
149 make("ZeroInit", {false}),
150 make("Inplace", {false})))
151{
152 validate(CLAccessor(_target), _reference, tolerance_f32);
153}
154
155// m+1-D to m+n-D cases
156FIXTURE_DATA_TEST_CASE(RunSmallMultiIndices, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
157 combine(datasets::SmallScatterMultiIndicesDataset(),
158 make("DataType", {DataType::F32}),
159 make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add }),
160 make("ZeroInit", {false}),
161 make("Inplace", {false, true})))
162{
163 validate(CLAccessor(_target), _reference, tolerance_f32);
164}
165
166// m+k, k-1-D m+n-D case
167FIXTURE_DATA_TEST_CASE(RunSmallBatchedMultiIndices, CLScatterLayerFixture<float>, framework::DatasetMode::DISABLED,
168 combine(datasets::SmallScatterBatchedDataset(),
169 make("DataType", {DataType::F32}),
170 make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add }),
171 make("ZeroInit", {false}),
172 make("Inplace", {false})))
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000173{
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000174 validate(CLAccessor(_target), _reference, tolerance_f32);
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000175}
Mohammed Suhail Munshi0e212362024-04-08 14:38:31 +0100176
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000177TEST_SUITE_END() // FP32
Gunes Bayir301e33f2024-04-29 17:00:14 +0100178
179TEST_SUITE(FP16)
180FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<half>, framework::DatasetMode::PRECOMMIT,
181 combine(datasets::SmallScatterMixedDataset(),
182 make("DataType", {DataType::F16}),
183 allScatterFunctions,
184 make("ZeroInit", {false}),
185 make("Inplace", {false})))
186{
187 validate(CLAccessor(_target), _reference, tolerance_f16);
188}
189TEST_SUITE_END() // FP16
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000190TEST_SUITE_END() // Float
Gunes Bayir301e33f2024-04-29 17:00:14 +0100191
192TEST_SUITE(Integer)
193TEST_SUITE(S32)
194FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<int32_t>, framework::DatasetMode::PRECOMMIT,
195 combine(datasets::SmallScatterMixedDataset(),
196 make("DataType", {DataType::S32}),
197 allScatterFunctions,
198 make("ZeroInit", {false}),
199 make("Inplace", {false})))
200{
201 validate(CLAccessor(_target), _reference, tolerance_int);
202}
203TEST_SUITE_END() // S32
204
205TEST_SUITE(S16)
206FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<int16_t>, framework::DatasetMode::PRECOMMIT,
207 combine(datasets::SmallScatterMixedDataset(),
208 make("DataType", {DataType::S16}),
209 allScatterFunctions,
210 make("ZeroInit", {false}),
211 make("Inplace", {false})))
212{
213 validate(CLAccessor(_target), _reference, tolerance_int);
214}
215TEST_SUITE_END() // S16
216
217TEST_SUITE(S8)
218FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
219 combine(datasets::SmallScatterMixedDataset(),
220 make("DataType", {DataType::S8}),
221 allScatterFunctions,
222 make("ZeroInit", {false}),
223 make("Inplace", {false})))
224{
225 validate(CLAccessor(_target), _reference, tolerance_int);
226}
227TEST_SUITE_END() // S8
228
229TEST_SUITE(U32)
230FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<uint32_t>, framework::DatasetMode::PRECOMMIT,
231 combine(datasets::SmallScatterMixedDataset(),
232 make("DataType", {DataType::U32}),
233 allScatterFunctions,
234 make("ZeroInit", {false}),
235 make("Inplace", {false})))
236{
237 validate(CLAccessor(_target), _reference, tolerance_int);
238}
239TEST_SUITE_END() // U32
240
241TEST_SUITE(U16)
242FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<uint16_t>, framework::DatasetMode::PRECOMMIT,
243 combine(datasets::SmallScatterMixedDataset(),
244 make("DataType", {DataType::U16}),
245 allScatterFunctions,
246 make("ZeroInit", {false}),
247 make("Inplace", {false})))
248{
249 validate(CLAccessor(_target), _reference, tolerance_int);
250}
251TEST_SUITE_END() // U16
252
253TEST_SUITE(U8)
254FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
255 combine(datasets::SmallScatterMixedDataset(),
256 make("DataType", {DataType::U8}),
257 allScatterFunctions,
258 make("ZeroInit", {false}),
259 make("Inplace", {false})))
260{
261 validate(CLAccessor(_target), _reference, tolerance_int);
262}
263TEST_SUITE_END() // U8
264TEST_SUITE_END() // Integer
265
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000266TEST_SUITE_END() // Scatter
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000267TEST_SUITE_END() // CL
268} // namespace validation
269} // namespace test
270} // namespace arm_compute