blob: 56338f489f7067c0f216920409ce316890f46afa [file] [log] [blame]
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +00001/*
2 * Copyright (c) 2024 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/CL/CLTensor.h"
25#include "arm_compute/runtime/CL/functions/CLScatter.h"
26#include "tests/validation/fixtures/ScatterLayerFixture.h"
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000027#include "tests/datasets/ScatterDataset.h"
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000028#include "tests/CL/CLAccessor.h"
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000029#include "arm_compute/function_info/ScatterInfo.h"
30#include "tests/framework/Asserts.h"
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000031#include "tests/framework/Macros.h"
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000032#include "tests/framework/datasets/Datasets.h"
33#include "tests/validation/Validation.h"
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000034
35namespace arm_compute
36{
37namespace test
38{
39namespace validation
40{
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000041
42template <typename T>
43using CLScatterLayerFixture = ScatterValidationFixture<CLTensor, CLAccessor, CLScatter, T>;
44
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000045using framework::dataset::make;
46
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000047TEST_SUITE(CL)
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000048TEST_SUITE(Scatter)
49DATA_TEST_CASE(Validate, framework::DatasetMode::DISABLED, zip(
50 make("InputInfo", { TensorInfo(TensorShape(9U), 1, DataType::F32), // Mismatching data types
51 TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid
52 TensorInfo(TensorShape(8U), 1, DataType::F32),
53 TensorInfo(TensorShape(217U), 1, DataType::F32), // Mismatch input/output dims.
54 TensorInfo(TensorShape(217U), 1, DataType::F32), // Updates dim higher than Input/Output dims.
55 TensorInfo(TensorShape(12U), 1, DataType::F32), // Indices wrong datatype.
56 }),
57 make("UpdatesInfo",{ TensorInfo(TensorShape(3U), 1, DataType::F16),
58 TensorInfo(TensorShape(15U), 1, DataType::F32),
59 TensorInfo(TensorShape(2U), 1, DataType::F32),
60 TensorInfo(TensorShape(217U), 1, DataType::F32),
61 TensorInfo(TensorShape(217U, 3U), 1, DataType::F32),
62 TensorInfo(TensorShape(2U), 1, DataType::F32),
63 }),
64 make("IndicesInfo",{ TensorInfo(TensorShape(3U), 1, DataType::U32),
65 TensorInfo(TensorShape(15U), 1, DataType::U32),
66 TensorInfo(TensorShape(2U), 1, DataType::U32),
67 TensorInfo(TensorShape(271U), 1, DataType::U32),
68 TensorInfo(TensorShape(271U), 1, DataType::U32),
69 TensorInfo(TensorShape(2U), 1 , DataType::S32)
70 }),
71 make("OutputInfo",{ TensorInfo(TensorShape(9U), 1, DataType::F16),
72 TensorInfo(TensorShape(15U), 1, DataType::F32),
73 TensorInfo(TensorShape(8U), 1, DataType::F32),
74 TensorInfo(TensorShape(271U, 3U), 1, DataType::F32),
75 TensorInfo(TensorShape(271U), 1, DataType::F32),
76 TensorInfo(TensorShape(12U), 1, DataType::F32)
77 }),
78 make("ScatterInfo",{ ScatterInfo(ScatterFunction::Add, false),
79 }),
80 make("Expected", { false, true, true, false, false, false })),
81 input_info, updates_info, indices_info, output_info, scatter_info, expected)
82{
83 // TODO: Enable validation tests.
84 ARM_COMPUTE_UNUSED(input_info);
85 ARM_COMPUTE_UNUSED(updates_info);
86 ARM_COMPUTE_UNUSED(indices_info);
87 ARM_COMPUTE_UNUSED(output_info);
88 ARM_COMPUTE_UNUSED(scatter_info);
89 ARM_COMPUTE_UNUSED(expected);
90}
91
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +000092TEST_SUITE(Float)
93TEST_SUITE(FP32)
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +000094FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(),
95 make("DataType", {DataType::F32}),
96 make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max}),
97 make("ZeroInit", {false})))
98{
99 // TODO: Add validate() here.
100}
101
102// With this test, src should be passed as nullptr.
103FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(),
104 make("DataType", {DataType::F32}),
105 make("ScatterFunction", {ScatterFunction::Add}),
106 make("ZeroInit", {true})))
107{
108 // TODO: Add validate() here
109}
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000110TEST_SUITE_END() // FP32
111TEST_SUITE_END() // Float
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000112TEST_SUITE_END() // Scatter
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000113TEST_SUITE_END() // CL
114} // namespace validation
115} // namespace test
116} // namespace arm_compute