blob: 5b1d5afe921481a01630cc7fb4e9464ae921716f [file] [log] [blame]
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +00001/*
2 * Copyright (c) 2024 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/CL/CLTensor.h"
25#include "arm_compute/runtime/CL/functions/CLScatter.h"
26#include "tests/validation/fixtures/ScatterLayerFixture.h"
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000027#include "tests/datasets/ScatterDataset.h"
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000028#include "tests/CL/CLAccessor.h"
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000029#include "arm_compute/function_info/ScatterInfo.h"
30#include "tests/framework/Asserts.h"
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000031#include "tests/framework/Macros.h"
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000032#include "tests/framework/datasets/Datasets.h"
33#include "tests/validation/Validation.h"
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000034
35namespace arm_compute
36{
37namespace test
38{
39namespace validation
40{
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000041namespace
42{
43RelativeTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for fp32 data type */
44} // namespace
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000045
46template <typename T>
47using CLScatterLayerFixture = ScatterValidationFixture<CLTensor, CLAccessor, CLScatter, T>;
48
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000049using framework::dataset::make;
50
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000051TEST_SUITE(CL)
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000052TEST_SUITE(Scatter)
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000053DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip(
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000054 make("InputInfo", { TensorInfo(TensorShape(9U), 1, DataType::F32), // Mismatching data types
55 TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid
56 TensorInfo(TensorShape(8U), 1, DataType::F32),
57 TensorInfo(TensorShape(217U), 1, DataType::F32), // Mismatch input/output dims.
58 TensorInfo(TensorShape(217U), 1, DataType::F32), // Updates dim higher than Input/Output dims.
59 TensorInfo(TensorShape(12U), 1, DataType::F32), // Indices wrong datatype.
60 }),
61 make("UpdatesInfo",{ TensorInfo(TensorShape(3U), 1, DataType::F16),
62 TensorInfo(TensorShape(15U), 1, DataType::F32),
63 TensorInfo(TensorShape(2U), 1, DataType::F32),
64 TensorInfo(TensorShape(217U), 1, DataType::F32),
65 TensorInfo(TensorShape(217U, 3U), 1, DataType::F32),
66 TensorInfo(TensorShape(2U), 1, DataType::F32),
67 }),
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000068 make("IndicesInfo",{ TensorInfo(TensorShape(1U, 3U), 1, DataType::S32),
69 TensorInfo(TensorShape(1U, 15U), 1, DataType::S32),
70 TensorInfo(TensorShape(1U, 2U), 1, DataType::S32),
71 TensorInfo(TensorShape(1U, 271U), 1, DataType::S32),
72 TensorInfo(TensorShape(1U, 271U), 1, DataType::S32),
73 TensorInfo(TensorShape(1U, 2U), 1 , DataType::F32)
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000074 }),
75 make("OutputInfo",{ TensorInfo(TensorShape(9U), 1, DataType::F16),
76 TensorInfo(TensorShape(15U), 1, DataType::F32),
77 TensorInfo(TensorShape(8U), 1, DataType::F32),
78 TensorInfo(TensorShape(271U, 3U), 1, DataType::F32),
79 TensorInfo(TensorShape(271U), 1, DataType::F32),
80 TensorInfo(TensorShape(12U), 1, DataType::F32)
81 }),
82 make("ScatterInfo",{ ScatterInfo(ScatterFunction::Add, false),
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000083 ScatterInfo(ScatterFunction::Max, false),
84 ScatterInfo(ScatterFunction::Min, false),
85 ScatterInfo(ScatterFunction::Add, false),
86 ScatterInfo(ScatterFunction::Update, false),
87 ScatterInfo(ScatterFunction::Sub, false),
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000088 }),
89 make("Expected", { false, true, true, false, false, false })),
90 input_info, updates_info, indices_info, output_info, scatter_info, expected)
91{
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +000092 const Status status = CLScatter::validate(&input_info.clone()->set_is_resizable(true), &updates_info.clone()->set_is_resizable(true), &indices_info.clone()->set_is_resizable(true), &output_info.clone()->set_is_resizable(true), scatter_info);
93 ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000094}
95
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000096TEST_SUITE(Float)
97TEST_SUITE(FP32)
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000098FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(),
99 make("DataType", {DataType::F32}),
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000100 make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max }),
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000101 make("ZeroInit", {false})))
102{
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000103 validate(CLAccessor(_target), _reference, tolerance_f32);
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000104}
105
106// With this test, src should be passed as nullptr.
107FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(),
108 make("DataType", {DataType::F32}),
109 make("ScatterFunction", {ScatterFunction::Add}),
110 make("ZeroInit", {true})))
111{
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000112 validate(CLAccessor(_target), _reference, tolerance_f32);
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000113}
Mohammed Suhail Munshi0e212362024-04-08 14:38:31 +0100114
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000115TEST_SUITE_END() // FP32
116TEST_SUITE_END() // Float
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000117TEST_SUITE_END() // Scatter
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000118TEST_SUITE_END() // CL
119} // namespace validation
120} // namespace test
121} // namespace arm_compute