blob: 1451ab3de8df812f1d1bdb5cc084920fb8a2c543 [file] [log] [blame]
/*
 * Copyright (c) 2022-2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24
25#include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
26#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuAdd.h"
27
28#include "tests/CL/CLAccessor.h"
29#include "tests/framework/Fixture.h"
30#include "tests/framework/Macros.h"
31#include "tests/framework/datasets/Datasets.h"
32#include "tests/validation/Validation.h"
33
34#include "tests/datasets/DynamicFusionDataset.h"
35#include "tests/datasets/ShapeDatasets.h"
36#include "tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h"
37#include "tests/validation/reference/ElementwiseOperations.h"
38
39namespace arm_compute
40{
41namespace test
42{
43namespace validation
44{
45TEST_SUITE(CL)
46TEST_SUITE(DYNAMIC_FUSION)
47TEST_SUITE(ADD)
48
49// *INDENT-OFF*
50// clang-format off
51DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
52 framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
53 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Invalid data type combination
54 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), // S16 is valid data type for Add
55 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32), // S32 is valid data type for Add
56 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
Viet-Hoa Dob3077fb2023-01-03 17:59:14 +000057 TensorInfo(TensorShape(32U, 1U, 1U), 1, DataType::F32), // Broadcasting allowed for lhs
Ramy Elgammal404462a2022-11-08 02:14:46 +000058 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
Viet-Hoa Dob3077fb2023-01-03 17:59:14 +000059 TensorInfo(TensorShape(15U, 23U, 3U), 1, DataType::F32), // Broadcast Y dimension is not allowed
60 TensorInfo(TensorShape( 3U, 8U, 9U), 1, DataType::S16), // Broadcast Z dimension is not allowed
61 TensorInfo(TensorShape(32U, 13U, 2U, 2), 1, DataType::F32), // Batching is allowed
Ramy Elgammal404462a2022-11-08 02:14:46 +000062 }),
63 framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
64 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F16),
65 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
66 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32),
67 TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
68 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
69 TensorInfo(TensorShape(32U, 1U, 1U), 1, DataType::F32), // Broadcasting allowed for rhs
Viet-Hoa Dob3077fb2023-01-03 17:59:14 +000070 TensorInfo(TensorShape(15U, 1U, 3U), 1, DataType::F32),
71 TensorInfo(TensorShape( 3U, 8U, 1U), 1, DataType::S16),
72 TensorInfo(TensorShape(32U, 13U, 2U, 2), 1, DataType::F32),
Ramy Elgammal404462a2022-11-08 02:14:46 +000073 })),
74 framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
75 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
76 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
77 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32),
78 TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
79 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
80 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
Viet-Hoa Dob3077fb2023-01-03 17:59:14 +000081 TensorInfo(TensorShape(15U, 23U, 3U), 1, DataType::F32),
82 TensorInfo(TensorShape( 3U, 8U, 9U), 1, DataType::S16),
Ramy Elgammal404462a2022-11-08 02:14:46 +000083 TensorInfo(TensorShape(32U, 13U, 2U, 2), 1, DataType::F32),
84 })),
Viet-Hoa Dob3077fb2023-01-03 17:59:14 +000085 framework::dataset::make("Expected", { true, false, true, true, false, true, true, false, false, true})),
Ramy Elgammal404462a2022-11-08 02:14:46 +000086 input1_info, input2_info, output_info, expected)
87{
88 // Create a new workload sketch
89 auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
90 auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx };
91 GpuWorkloadSketch sketch{ &gpu_ctx };
92
93 // Fuse Elementwise Add
94 auto lhs_info = sketch.create_tensor_info(input1_info);
95 auto rhs_info = sketch.create_tensor_info(input2_info);
96 auto dst_info = sketch.create_tensor_info(output_info);
97 bool res = bool(GpuAdd::validate_op(sketch, &lhs_info, &rhs_info, &dst_info));
98 ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
99}
100
101DATA_TEST_CASE(ValidateRhsInplace, framework::DatasetMode::ALL, zip(zip(
102 framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 1U, 1U), 1, DataType::F32), // Broadcasting allowed for lhs
103 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
104 }),
105 framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
106 TensorInfo(TensorShape(32U, 1U, 1U), 1, DataType::F32), // Broadcasting not allowed for rhs
107 })),
108 framework::dataset::make("Expected", { true, false})),
109 input1_info, input2_info, expected)
110{
111 // Create a new workload sketch
112 auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
113 auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx };
114 GpuWorkloadSketch sketch{ &gpu_ctx };
115
116 // Fuse Elementwise Add
117 auto lhs_info = sketch.create_tensor_info(input1_info);
118 auto rhs_info = sketch.create_tensor_info(input2_info);
119 bool res = bool(GpuAdd::validate_op(sketch, &lhs_info, &rhs_info, &rhs_info));
120 ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
121}
122
123DATA_TEST_CASE(ValidateLhsInplace, framework::DatasetMode::ALL, zip(zip(
124 framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 1U, 1U), 1, DataType::F32), // Broadcasting not allowed for lhs
125 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
126 }),
127 framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
128 TensorInfo(TensorShape(32U, 1U, 1U), 1, DataType::F32), // Broadcasting allowed for rhs
129 })),
130 framework::dataset::make("Expected", { false, true})),
131 input1_info, input2_info, expected)
132{
133 // Create a new workload sketch
134 auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
135 auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx };
136 GpuWorkloadSketch sketch{ &gpu_ctx };
137
138 // Fuse Elementwise Add
139 auto lhs_info = sketch.create_tensor_info(input1_info);
140 auto rhs_info = sketch.create_tensor_info(input2_info);
141 bool res = bool(GpuAdd::validate_op(sketch, &lhs_info, &rhs_info, &lhs_info));
142 ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
143}
144// clang-format on
145// *INDENT-ON*
146
147RelativeTolerance<float> tolerance_f32(0.01f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
148RelativeTolerance<half_float::half> tolerance_f16(half_float::half(0.1)); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
149constexpr float tolerance_num = 0.01f; /**< Tolerance number */
150
151template <typename T>
152using DynamicFusionAddOpFixture = DynamicFusionGpuElementwiseBinaryOneOpValidationFixture<CLTensor, CLAccessor, GpuAdd, T>;
153
154template <typename T>
155using DynamicFusionAddOpBroadcastFixture = DynamicFusionGpuElementwiseBinaryBroadcastOneOpValidationFixture<CLTensor, CLAccessor, GpuAdd, T>;
156
157template <typename T>
158using DynamicFusionGpuFuseTwoAddOpsFixture = DynamicFusionGpuElementwiseBinaryTwoOpsValidationFixture<CLTensor, CLAccessor, GpuAdd, T>;
159
160TEST_SUITE(FP32)
161FIXTURE_DATA_TEST_CASE(RunSmallOneOp, DynamicFusionAddOpFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(
162 framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
163 datasets::SmallShapesNoBatches()),
164 framework::dataset::make("DataType", { DataType::F32 })),
165 framework::dataset::make("InPlace", { false, true })))
166{
167 // Validate output
168 validate(CLAccessor(_target), _reference, tolerance_f32);
169}
170FIXTURE_DATA_TEST_CASE(RunLargeOneOp, DynamicFusionAddOpFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(
171 framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
172 datasets::LargeShapesNoBatches()),
173 framework::dataset::make("DataType", { DataType::F32 })),
174 framework::dataset::make("InPlace", { false, true })))
175{
176 // Validate output
177 validate(CLAccessor(_target), _reference, tolerance_f32);
178}
179FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp, DynamicFusionAddOpBroadcastFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
180 datasets::TemporaryLimitedSmallShapesBroadcast()),
181 framework::dataset::make("DataType", { DataType::F32 })),
182 framework::dataset::make("InPlace", { false, true })))
183{
184 // Validate output
185 validate(CLAccessor(_target), _reference, tolerance_f32);
186}
187
188FIXTURE_DATA_TEST_CASE(RunLargeBroadcastOneOp, DynamicFusionAddOpBroadcastFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
189 datasets::TemporaryLimitedLargeShapesBroadcast()),
190 framework::dataset::make("DataType", { DataType::F32 })),
191 framework::dataset::make("InPlace", { false, true })))
192{
193 // Validate output
194 validate(CLAccessor(_target), _reference, tolerance_f32);
195}
196FIXTURE_DATA_TEST_CASE(RunSmallTwoOps, DynamicFusionGpuFuseTwoAddOpsFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
197 datasets::DynamicFusionElementwiseBinaryTwoOpsSmallShapes()),
198 framework::dataset::make("DataType", { DataType::F32 })),
199 framework::dataset::make("InPlace", { false })))
200{
201 // Validate output
202 validate(CLAccessor(_target), _reference, tolerance_f32);
203}
204TEST_SUITE_END() // FP32
205
206TEST_SUITE(FP16)
207FIXTURE_DATA_TEST_CASE(RunSmallOneOp, DynamicFusionAddOpFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
208 datasets::SmallShapesNoBatches()),
209 framework::dataset::make("DataType", { DataType::F16 })),
210 framework::dataset::make("InPlace", { false, true })))
211{
212 // Validate output
213 validate(CLAccessor(_target), _reference, tolerance_f32, tolerance_num);
214}
215
216FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp, DynamicFusionAddOpBroadcastFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
217 datasets::TemporaryLimitedSmallShapesBroadcast()),
218 framework::dataset::make("DataType", { DataType::F16 })),
219 framework::dataset::make("InPlace", { false })))
220{
221 // Validate output
222 validate(CLAccessor(_target), _reference, tolerance_f32, tolerance_num);
223}
224
225TEST_SUITE_END() // FP16
226
227TEST_SUITE(S32)
228FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionAddOpFixture<int32_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
229 datasets::SmallShapesNoBatches()),
230 framework::dataset::make("DataType", { DataType::S32 })),
231 framework::dataset::make("InPlace", { false })))
232{
233 // Validate output
234 validate(CLAccessor(_target), _reference);
235}
236TEST_SUITE_END() // S32
237
238TEST_SUITE(S16)
239FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionAddOpFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
240 datasets::SmallShapesNoBatches()),
241 framework::dataset::make("DataType", { DataType::S16 })),
242 framework::dataset::make("InPlace", { false })))
243{
244 // Validate output
245 validate(CLAccessor(_target), _reference);
246}
247FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionAddOpFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
248 datasets::LargeShapesNoBatches()),
249 framework::dataset::make("DataType", { DataType::S16 })),
250 framework::dataset::make("InPlace", { false })))
251{
252 // Validate output
253 validate(CLAccessor(_target), _reference);
254}
255TEST_SUITE_END() // S16
256
257TEST_SUITE(U8)
258FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionAddOpFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
259 datasets::SmallShapesNoBatches()),
260 framework::dataset::make("DataType", { DataType::U8 })),
261 framework::dataset::make("InPlace", { false })))
262{
263 // Validate output
264 validate(CLAccessor(_target), _reference);
265}
266TEST_SUITE_END() // U8
267
268TEST_SUITE_END() // ADD
269TEST_SUITE_END() // DYNAMIC_FUSION
270TEST_SUITE_END() // CL
271} // namespace validation
272} // namespace test
273} // namespace arm_compute