blob: b1531eb64ac61020c00420306d45d185546b95d1 [file] [log] [blame]
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +00001/*
2 * Copyright (c) 2024 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/CL/CLTensor.h"
25#include "arm_compute/runtime/CL/functions/CLScatter.h"
26#include "tests/validation/fixtures/ScatterLayerFixture.h"
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000027#include "tests/datasets/ScatterDataset.h"
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000028#include "tests/CL/CLAccessor.h"
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000029#include "arm_compute/function_info/ScatterInfo.h"
30#include "tests/framework/Asserts.h"
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000031#include "tests/framework/Macros.h"
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000032#include "tests/framework/datasets/Datasets.h"
33#include "tests/validation/Validation.h"
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000034
35namespace arm_compute
36{
37namespace test
38{
39namespace validation
40{
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000041namespace
42{
43RelativeTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for fp32 data type */
Gunes Bayir301e33f2024-04-29 17:00:14 +010044RelativeTolerance<float> tolerance_f16(0.02f); /**< Tolerance value for comparing reference's output against implementation's output for fp16 data type */
45RelativeTolerance<int32_t> tolerance_int(0); /**< Tolerance value for comparing reference's output against implementation's output for integer data types */
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000046} // namespace
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000047
48template <typename T>
49using CLScatterLayerFixture = ScatterValidationFixture<CLTensor, CLAccessor, CLScatter, T>;
50
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000051using framework::dataset::make;
52
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000053TEST_SUITE(CL)
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000054TEST_SUITE(Scatter)
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000055DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip(
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000056 make("InputInfo", { TensorInfo(TensorShape(9U), 1, DataType::F32), // Mismatching data types
Gunes Bayirada32002024-04-24 10:27:13 +010057 TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid
Gunes Bayir301e33f2024-04-29 17:00:14 +010058 TensorInfo(TensorShape(15U), 1, DataType::U8), // Valid
Gunes Bayirada32002024-04-24 10:27:13 +010059 TensorInfo(TensorShape(8U), 1, DataType::F32),
60 TensorInfo(TensorShape(217U), 1, DataType::F32), // Mismatch input/output dims.
61 TensorInfo(TensorShape(217U), 1, DataType::F32), // Updates dim higher than Input/Output dims.
62 TensorInfo(TensorShape(12U), 1, DataType::F32), // Indices wrong datatype.
63 TensorInfo(TensorShape(9U, 3U, 4U), 1, DataType::F32), // Number of updates != number of indices
64 TensorInfo(TensorShape(17U, 3U, 3U, 2U), 1, DataType::F32), // index_len != (dst_dims - upt_dims + 1)
65 TensorInfo(TensorShape(17U, 3U, 3U, 2U, 2U, 2U), 1, DataType::F32), // index_len > 5
66 }),
67 make("UpdatesInfo",{TensorInfo(TensorShape(3U), 1, DataType::F16),
68 TensorInfo(TensorShape(15U), 1, DataType::F32),
Gunes Bayir301e33f2024-04-29 17:00:14 +010069 TensorInfo(TensorShape(15U), 1, DataType::U8),
Gunes Bayirada32002024-04-24 10:27:13 +010070 TensorInfo(TensorShape(2U), 1, DataType::F32),
71 TensorInfo(TensorShape(217U), 1, DataType::F32),
72 TensorInfo(TensorShape(217U, 3U), 1, DataType::F32),
73 TensorInfo(TensorShape(2U), 1, DataType::F32),
74 TensorInfo(TensorShape(9U, 3U, 2U), 1, DataType::F32),
75 TensorInfo(TensorShape(17U, 3U, 2U), 1, DataType::F32),
76 TensorInfo(TensorShape(1U), 1, DataType::F32),
77 }),
78 make("IndicesInfo",{TensorInfo(TensorShape(1U, 3U), 1, DataType::S32),
79 TensorInfo(TensorShape(1U, 15U), 1, DataType::S32),
Gunes Bayir301e33f2024-04-29 17:00:14 +010080 TensorInfo(TensorShape(1U, 15U), 1, DataType::S32),
Gunes Bayirada32002024-04-24 10:27:13 +010081 TensorInfo(TensorShape(1U, 2U), 1, DataType::S32),
82 TensorInfo(TensorShape(1U, 271U), 1, DataType::S32),
83 TensorInfo(TensorShape(1U, 271U), 1, DataType::S32),
84 TensorInfo(TensorShape(1U, 2U), 1 , DataType::F32),
85 TensorInfo(TensorShape(1U, 4U), 1, DataType::S32),
86 TensorInfo(TensorShape(3U, 2U), 1, DataType::S32),
87 TensorInfo(TensorShape(6U, 2U), 1, DataType::S32),
88 }),
89 make("OutputInfo",{TensorInfo(TensorShape(9U), 1, DataType::F16),
90 TensorInfo(TensorShape(15U), 1, DataType::F32),
Gunes Bayir301e33f2024-04-29 17:00:14 +010091 TensorInfo(TensorShape(15U), 1, DataType::U8),
Gunes Bayirada32002024-04-24 10:27:13 +010092 TensorInfo(TensorShape(8U), 1, DataType::F32),
93 TensorInfo(TensorShape(271U, 3U), 1, DataType::F32),
94 TensorInfo(TensorShape(271U), 1, DataType::F32),
95 TensorInfo(TensorShape(12U), 1, DataType::F32),
96 TensorInfo(TensorShape(9U, 3U, 4U), 1, DataType::F32),
97 TensorInfo(TensorShape(17U, 3U, 3U, 2U), 1, DataType::F32),
98 TensorInfo(TensorShape(17U, 3U, 3U, 2U, 2U, 2U), 1, DataType::F32),
99 }),
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000100 make("ScatterInfo",{ ScatterInfo(ScatterFunction::Add, false),
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000101 ScatterInfo(ScatterFunction::Max, false),
Gunes Bayir301e33f2024-04-29 17:00:14 +0100102 ScatterInfo(ScatterFunction::Max, false),
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000103 ScatterInfo(ScatterFunction::Min, false),
104 ScatterInfo(ScatterFunction::Add, false),
105 ScatterInfo(ScatterFunction::Update, false),
106 ScatterInfo(ScatterFunction::Sub, false),
Gunes Bayirada32002024-04-24 10:27:13 +0100107 ScatterInfo(ScatterFunction::Sub, false),
108 ScatterInfo(ScatterFunction::Update, false),
109 ScatterInfo(ScatterFunction::Update, false),
110 }),
Gunes Bayir301e33f2024-04-29 17:00:14 +0100111 make("Expected", { false, true, true, true, false, false, false, false, false, false })),
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000112 input_info, updates_info, indices_info, output_info, scatter_info, expected)
113{
Gunes Bayirada32002024-04-24 10:27:13 +0100114 const Status status = CLScatter::validate(&input_info, &updates_info, &indices_info, &output_info, scatter_info);
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000115 ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000116}
117
Gunes Bayirada32002024-04-24 10:27:13 +0100118const auto allScatterFunctions = make("ScatterFunction",
119 {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max });
120
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000121TEST_SUITE(Float)
122TEST_SUITE(FP32)
Gunes Bayirada32002024-04-24 10:27:13 +0100123FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
124 combine(datasets::Small1DScatterDataset(),
125 make("DataType", {DataType::F32}),
126 allScatterFunctions,
127 make("ZeroInit", {false}),
Gunes Bayir05269f02024-05-09 13:24:15 +0100128 make("Inplace", {false}),
129 make("Padding", {true})))
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000130{
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000131 validate(CLAccessor(_target), _reference, tolerance_f32);
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000132}
133
134// With this test, src should be passed as nullptr.
Gunes Bayirada32002024-04-24 10:27:13 +0100135FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
136 combine(datasets::Small1DScatterDataset(),
137 make("DataType", {DataType::F32}),
138 make("ScatterFunction", {ScatterFunction::Add}),
139 make("ZeroInit", {true}),
Gunes Bayir05269f02024-05-09 13:24:15 +0100140 make("Inplace", {false}),
141 make("Padding", {true})))
Gunes Bayirada32002024-04-24 10:27:13 +0100142{
143 validate(CLAccessor(_target), _reference, tolerance_f32);
144}
145
146// Updates/src/dst have same no. dims.
147FIXTURE_DATA_TEST_CASE(RunSmallMultiDim, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
148 combine(datasets::SmallScatterMultiDimDataset(),
149 make("DataType", {DataType::F32}),
150 allScatterFunctions,
151 make("ZeroInit", {false}),
Gunes Bayir05269f02024-05-09 13:24:15 +0100152 make("Inplace", {false}),
153 make("Padding", {true})))
Gunes Bayirada32002024-04-24 10:27:13 +0100154{
155 validate(CLAccessor(_target), _reference, tolerance_f32);
156}
157
158// m+1-D to m+n-D cases
159FIXTURE_DATA_TEST_CASE(RunSmallMultiIndices, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
160 combine(datasets::SmallScatterMultiIndicesDataset(),
161 make("DataType", {DataType::F32}),
162 make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add }),
163 make("ZeroInit", {false}),
Gunes Bayir05269f02024-05-09 13:24:15 +0100164 make("Inplace", {false, true}),
165 make("Padding", {true})))
Gunes Bayirada32002024-04-24 10:27:13 +0100166{
167 validate(CLAccessor(_target), _reference, tolerance_f32);
168}
169
170// m+k, k-1-D m+n-D case
Mohammed Suhail Munshi2fea1352024-04-29 22:53:58 +0100171FIXTURE_DATA_TEST_CASE(RunSmallBatchedMultiIndices, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
Gunes Bayirada32002024-04-24 10:27:13 +0100172 combine(datasets::SmallScatterBatchedDataset(),
173 make("DataType", {DataType::F32}),
Mohammed Suhail Munshi2fea1352024-04-29 22:53:58 +0100174 make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add}),
Gunes Bayirada32002024-04-24 10:27:13 +0100175 make("ZeroInit", {false}),
Gunes Bayir05269f02024-05-09 13:24:15 +0100176 make("Inplace", {false}),
177 make("Padding", {true})))
178{
179 validate(CLAccessor(_target), _reference, tolerance_f32);
180}
181
182// m+k, k-1-D m+n-D case
183FIXTURE_DATA_TEST_CASE(RunSmallScatterScalar, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
184 combine(datasets::SmallScatterScalarDataset(),
185 make("DataType", {DataType::F32}),
186 make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add}),
187 make("ZeroInit", {false}),
188 make("Inplace", {false}),
189 make("Padding", {false}))) // NOTE: Padding not supported in this datset
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000190{
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000191 validate(CLAccessor(_target), _reference, tolerance_f32);
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000192}
Mohammed Suhail Munshi0e212362024-04-08 14:38:31 +0100193
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000194TEST_SUITE_END() // FP32
Gunes Bayir301e33f2024-04-29 17:00:14 +0100195
Gunes Bayir05269f02024-05-09 13:24:15 +0100196
// NOTE: Padding is disabled for SmallScatterMixedDataset because some of its shapes do not support padding.
// Padding is covered by the FP32 test cases.
199
Gunes Bayir301e33f2024-04-29 17:00:14 +0100200TEST_SUITE(FP16)
201FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<half>, framework::DatasetMode::PRECOMMIT,
202 combine(datasets::SmallScatterMixedDataset(),
203 make("DataType", {DataType::F16}),
204 allScatterFunctions,
205 make("ZeroInit", {false}),
Gunes Bayir05269f02024-05-09 13:24:15 +0100206 make("Inplace", {false}),
207 make("Padding", {false})))
Gunes Bayir301e33f2024-04-29 17:00:14 +0100208{
209 validate(CLAccessor(_target), _reference, tolerance_f16);
210}
211TEST_SUITE_END() // FP16
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000212TEST_SUITE_END() // Float
Gunes Bayir301e33f2024-04-29 17:00:14 +0100213
214TEST_SUITE(Integer)
215TEST_SUITE(S32)
216FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<int32_t>, framework::DatasetMode::PRECOMMIT,
217 combine(datasets::SmallScatterMixedDataset(),
218 make("DataType", {DataType::S32}),
219 allScatterFunctions,
220 make("ZeroInit", {false}),
Gunes Bayir05269f02024-05-09 13:24:15 +0100221 make("Inplace", {false}),
222 make("Padding", {false})))
Gunes Bayir301e33f2024-04-29 17:00:14 +0100223{
224 validate(CLAccessor(_target), _reference, tolerance_int);
225}
226TEST_SUITE_END() // S32
227
228TEST_SUITE(S16)
229FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<int16_t>, framework::DatasetMode::PRECOMMIT,
230 combine(datasets::SmallScatterMixedDataset(),
231 make("DataType", {DataType::S16}),
232 allScatterFunctions,
233 make("ZeroInit", {false}),
Gunes Bayir05269f02024-05-09 13:24:15 +0100234 make("Inplace", {false}),
235 make("Padding", {false})))
Gunes Bayir301e33f2024-04-29 17:00:14 +0100236{
237 validate(CLAccessor(_target), _reference, tolerance_int);
238}
239TEST_SUITE_END() // S16
240
241TEST_SUITE(S8)
242FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
243 combine(datasets::SmallScatterMixedDataset(),
244 make("DataType", {DataType::S8}),
245 allScatterFunctions,
246 make("ZeroInit", {false}),
Gunes Bayir05269f02024-05-09 13:24:15 +0100247 make("Inplace", {false}),
248 make("Padding", {false})))
Gunes Bayir301e33f2024-04-29 17:00:14 +0100249{
250 validate(CLAccessor(_target), _reference, tolerance_int);
251}
252TEST_SUITE_END() // S8
253
254TEST_SUITE(U32)
255FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<uint32_t>, framework::DatasetMode::PRECOMMIT,
256 combine(datasets::SmallScatterMixedDataset(),
257 make("DataType", {DataType::U32}),
258 allScatterFunctions,
259 make("ZeroInit", {false}),
Gunes Bayir05269f02024-05-09 13:24:15 +0100260 make("Inplace", {false}),
261 make("Padding", {false})))
Gunes Bayir301e33f2024-04-29 17:00:14 +0100262{
263 validate(CLAccessor(_target), _reference, tolerance_int);
264}
265TEST_SUITE_END() // U32
266
267TEST_SUITE(U16)
268FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<uint16_t>, framework::DatasetMode::PRECOMMIT,
269 combine(datasets::SmallScatterMixedDataset(),
270 make("DataType", {DataType::U16}),
271 allScatterFunctions,
272 make("ZeroInit", {false}),
Gunes Bayir05269f02024-05-09 13:24:15 +0100273 make("Inplace", {false}),
274 make("Padding", {false})))
Gunes Bayir301e33f2024-04-29 17:00:14 +0100275{
276 validate(CLAccessor(_target), _reference, tolerance_int);
277}
278TEST_SUITE_END() // U16
279
280TEST_SUITE(U8)
281FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
282 combine(datasets::SmallScatterMixedDataset(),
283 make("DataType", {DataType::U8}),
284 allScatterFunctions,
285 make("ZeroInit", {false}),
Gunes Bayir05269f02024-05-09 13:24:15 +0100286 make("Inplace", {false}),
287 make("Padding", {false})))
Gunes Bayir301e33f2024-04-29 17:00:14 +0100288{
289 validate(CLAccessor(_target), _reference, tolerance_int);
290}
291TEST_SUITE_END() // U8
292TEST_SUITE_END() // Integer
293
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000294TEST_SUITE_END() // Scatter
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000295TEST_SUITE_END() // CL
296} // namespace validation
297} // namespace test
298} // namespace arm_compute