/*
 * Copyright (c) 2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuMul.h"

#include "tests/CL/CLAccessor.h"
#include "tests/framework/Fixture.h"
#include "tests/framework/Macros.h"
#include "tests/framework/datasets/Datasets.h"
#include "tests/validation/Validation.h"

#include "tests/datasets/DynamicFusionDataset.h"
#include "tests/datasets/ShapeDatasets.h"
#include "tests/validation/fixtures/dynamic_fusion/operators/MulFixture.h"

namespace arm_compute
{
namespace test
{
namespace validation
{
/* Synced with tests/validation/CL/PixelwiseMultiplication.cpp from the standard interface.
 *
 * Difference               | Why the difference
 * No integer tests         | Not supported yet
 * No quantized tests       | Not supported yet
 * No convert policy tests  | Not needed, as the convert policy is ignored for floating-point types
 * No scale tests           | Not supported yet
 * No rounding mode tests   | Not supported yet
 * No in-place tests        | Not supported yet
 * No activation tests      | Not needed by the dynamic fusion interface
 *
 */
namespace
{
constexpr AbsoluteTolerance<float> tolerance_f16(0.0001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
constexpr AbsoluteTolerance<float> tolerance_f32(0.0001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
} // namespace
TEST_SUITE(CL)
TEST_SUITE(DYNAMIC_FUSION)
TEST_SUITE(MUL)

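// The Validate test below only exercises static validation: each zipped row pairs an lhs/rhs
// TensorInfo combination with the boolean result expected from GpuMul::validate_op().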
// *INDENT-OFF*
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(
               framework::dataset::make("LhsInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
                                                     TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F16),
                                                     TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),            // Invalid data type combination
                                                     TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),             // Unsupported data type U8
                                                     TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S8),             // Unsupported data type S8
                                                     TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),            // Unsupported data type S16
                                                     TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32),            // Unsupported data type S32
                                                     TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8),        // Unsupported data type QASYMM8
                                                     TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8_SIGNED), // Unsupported data type QASYMM8_SIGNED
                                                     TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),            // Mismatching shapes
                                                     TensorInfo(TensorShape(32U, 1U, 1U), 1, DataType::F32),             // Broadcasting allowed for lhs
                                                     TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
                                                     TensorInfo(TensorShape(15U, 23U, 3U), 1, DataType::F32),            // Broadcast Y dimension is not allowed
                                                     TensorInfo(TensorShape( 3U, 8U, 9U), 1, DataType::F32),             // Broadcast Z dimension is not allowed
                                                     TensorInfo(TensorShape(32U, 13U, 2U, 2), 1, DataType::F32),         // Batching is allowed
                                                   }),
               framework::dataset::make("RhsInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
                                                    TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F16),
                                                    TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F16),
                                                    TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
                                                    TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S8),
                                                    TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
                                                    TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32),
                                                    TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8),
                                                    TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8_SIGNED),
                                                    TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
                                                    TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
                                                    TensorInfo(TensorShape(32U, 1U, 1U), 1, DataType::F32),              // Broadcasting allowed for rhs
                                                    TensorInfo(TensorShape(15U, 1U, 3U), 1, DataType::F32),
                                                    TensorInfo(TensorShape( 3U, 8U, 1U), 1, DataType::F32),
                                                    TensorInfo(TensorShape(32U, 13U, 2U, 2), 1, DataType::F32),
                                                  })),
               framework::dataset::make("Expected", { true, true, false, false, false, false, false, false, false, false, true, true, false, false, true })),
               input1_info, input2_info, expected)
{
    // Create a new workload sketch
    auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
    GpuWorkloadSketch sketch{ &context };

    // Validate Elementwise Mul
    auto lhs_info = context.create_tensor_info(input1_info);
    auto rhs_info = context.create_tensor_info(input2_info);

    bool res = bool(GpuMul::validate_op(sketch, &lhs_info, &rhs_info));
    ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
}
// clang-format on
// *INDENT-ON*

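// Fixture aliases for the run tests: a single Mul operator, a Mul with a broadcast operand,
// and two Mul operators created in the same sketch.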
template <typename T>
using DynamicFusionCLMulFixture = DynamicFusionMulOneOpValidationFixture<CLTensor, CLAccessor, GpuMul, T>;
template <typename T>
using DynamicFusionCLMulBroadcastFixture = DynamicFusionMulBroadcastValidationFixture<CLTensor, CLAccessor, GpuMul, T>;
template <typename T>
using DynamicFusionCLMulTwoOpsFixture = DynamicFusionMulTwoOpsValidationFixture<CLTensor, CLAccessor, GpuMul, T>;

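// Each run test below executes the fixture's target workload and compares its output against
// the fixture's reference, using the tolerances defined at the top of this file.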
TEST_SUITE(F16)
FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
                       DynamicFusionCLMulFixture<half>,
                       framework::DatasetMode::ALL,
                       combine(combine(datasets::SmallShapes(),
                                       framework::dataset::make("DataType", { DataType::F16 })),
                               framework::dataset::make("InPlace", { false })))
{
    // Validate output
    validate(CLAccessor(_target), _reference, tolerance_f16);
}

FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp,
                       DynamicFusionCLMulBroadcastFixture<half>,
                       framework::DatasetMode::PRECOMMIT,
                       combine(combine(datasets::TemporaryLimitedSmallShapesBroadcast(),
                                       framework::dataset::make("DataType", { DataType::F16 })),
                               framework::dataset::make("InPlace", { false })))
{
    // Validate output
    validate(CLAccessor(_target), _reference, tolerance_f16);
}

FIXTURE_DATA_TEST_CASE(RunLargeBroadcastOneOp,
                       DynamicFusionCLMulBroadcastFixture<half>,
                       framework::DatasetMode::NIGHTLY,
                       combine(combine(datasets::TemporaryLimitedLargeShapesBroadcast(),
                                       framework::dataset::make("DataType", { DataType::F16 })),
                               framework::dataset::make("InPlace", { false })))
{
    // Validate output
    validate(CLAccessor(_target), _reference, tolerance_f16);
}
TEST_SUITE_END() // F16

TEST_SUITE(F32)
FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
                       DynamicFusionCLMulFixture<float>,
                       framework::DatasetMode::PRECOMMIT,
                       combine(combine(datasets::SmallShapes(),
                                       framework::dataset::make("DataType", { DataType::F32 })),
                               framework::dataset::make("InPlace", { false })))
{
    // Validate output
    validate(CLAccessor(_target), _reference, tolerance_f32);
}

FIXTURE_DATA_TEST_CASE(RunLargeOneOp,
                       DynamicFusionCLMulFixture<float>,
                       framework::DatasetMode::NIGHTLY,
                       combine(combine(datasets::LargeShapes(),
                                       framework::dataset::make("DataType", { DataType::F32 })),
                               framework::dataset::make("InPlace", { false })))
{
    // Validate output
    validate(CLAccessor(_target), _reference, tolerance_f32);
}

FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp,
                       DynamicFusionCLMulBroadcastFixture<float>,
                       framework::DatasetMode::PRECOMMIT,
                       combine(combine(datasets::TemporaryLimitedSmallShapesBroadcast(),
                                       framework::dataset::make("DataType", { DataType::F32 })),
                               framework::dataset::make("InPlace", { false })))
{
    // Validate output
    validate(CLAccessor(_target), _reference, tolerance_f32);
}

FIXTURE_DATA_TEST_CASE(RunLargeBroadcastOneOp,
                       DynamicFusionCLMulBroadcastFixture<float>,
                       framework::DatasetMode::NIGHTLY,
                       combine(combine(datasets::TemporaryLimitedLargeShapesBroadcast(),
                                       framework::dataset::make("DataType", { DataType::F32 })),
                               framework::dataset::make("InPlace", { false })))
{
    // Validate output
    validate(CLAccessor(_target), _reference, tolerance_f32);
}

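// RunSmallTwoOps creates two Mul operators in one sketch; "FuseTwoOps" is set so the fixture
// attempts to fuse both operators into a single workload.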
FIXTURE_DATA_TEST_CASE(RunSmallTwoOps,
                       DynamicFusionCLMulTwoOpsFixture<float>,
                       framework::DatasetMode::PRECOMMIT,
                       combine(combine(combine(datasets::DynamicFusionElementwiseBinaryTwoOpsSmallShapes(),
                                               framework::dataset::make("DataType", { DataType::F32 })),
                                       framework::dataset::make("InPlace", { false })),
                               framework::dataset::make("FuseTwoOps", { true })))
{
    // Validate output
    validate(CLAccessor(_target), _reference, tolerance_f32);
}
TEST_SUITE_END() // F32

TEST_SUITE_END() // MUL
TEST_SUITE_END() // DYNAMIC_FUSION
TEST_SUITE_END() // CL
} // namespace validation
} // namespace test
} // namespace arm_compute