blob: ebf2331e5e8038feb5d127922c60fd4c24f07217 [file] [log] [blame]
Pablo Tello4a626a72018-04-04 10:01:14 +01001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/CL/kernels/CLIm2ColKernel.h"
25#include "arm_compute/core/Types.h"
26#include "tests/CL/Helper.h"
27
28#include "tests/CL/CLAccessor.h"
29#include "tests/datasets/ShapeDatasets.h"
30#include "tests/framework/Asserts.h"
31#include "tests/framework/Macros.h"
32#include "tests/framework/datasets/Datasets.h"
33#include "tests/validation/Validation.h"
34#include "tests/validation/fixtures/Im2ColFixture.h"
35
36namespace arm_compute
37{
38namespace test
39{
40namespace validation
41{
42namespace
43{
Gian Marco Iodice215b4ea2018-06-28 16:29:29 +010044// *INDENT-OFF*
45// clang-format off
46const auto conv_filter_sizes = framework::dataset::make("KernelDims", { Size2D(3U, 3U),
47 Size2D(5U, 5U),
48 Size2D(3U, 1U),
49 Size2D(1U, 3U),
50 Size2D(5U, 3U),
51 Size2D(1U, 1U),
52 Size2D(11U, 11U)} );
53const auto padstrides = framework::dataset::make("PadStride", { PadStrideInfo(1U, 1U, 0U, 0U),
54 PadStrideInfo(1U, 1U, 1U, 1U),
55 PadStrideInfo(2U, 2U, 0U, 2U) });
Giorgio Arena0f170392018-07-18 16:13:12 +010056const auto conv_args = combine(combine(combine(combine(conv_filter_sizes, padstrides),
57 framework::dataset::make("QuantizationInfo", QuantizationInfo(0.5f, 10))),
58 framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
59 framework::dataset::make("NumGroups", { 1 }));
60const auto grouped_args = combine(combine(combine(combine(conv_filter_sizes, padstrides),
61 framework::dataset::make("QuantizationInfo", QuantizationInfo(0.5f, 10))),
62 framework::dataset::make("DataLayout", { DataLayout::NCHW })),
63 framework::dataset::make("NumGroups", { 2, 3, 4 }));
Pablo Tello4a626a72018-04-04 10:01:14 +010064
65} // namespace
66TEST_SUITE(CL)
67TEST_SUITE(Im2Col)
68
69using CLIm2Col = CLSynthetizeFunction<CLIm2ColKernel>;
70
Pablo Tello4a626a72018-04-04 10:01:14 +010071DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
72 framework::dataset::make("InputInfo", { TensorInfo(TensorShape(10U, 12U, 2U), 1, DataType::U8), // Unsupported data type
73 TensorInfo(TensorShape(10U, 12U, 2U), 1, DataType::F32), // Mismatching data type
Pablo Tello4a626a72018-04-04 10:01:14 +010074 TensorInfo(TensorShape(10U, 12U, 2U), 1, DataType::QASYMM8), // Bias not supported with QASYMM8
Pablo Tello4a626a72018-04-04 10:01:14 +010075 TensorInfo(TensorShape(10U, 12U, 2U, 2U), 1, DataType::QASYMM8),
76 }),
77 framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(3U, 4U, 10U, 2U), 1, DataType::F16),
78 TensorInfo(TensorShape(3U, 4U, 10U, 2U), 1, DataType::F16),
Pablo Tello4a626a72018-04-04 10:01:14 +010079 TensorInfo(TensorShape(3U, 3U, 10U, 2U), 1, DataType::QASYMM8),
Gian Marco Iodice215b4ea2018-06-28 16:29:29 +010080 TensorInfo(TensorShape(18U, 80U, 2U, 1U), 1, DataType::QASYMM8),
Pablo Tello4a626a72018-04-04 10:01:14 +010081 })),
Gian Marco Iodice215b4ea2018-06-28 16:29:29 +010082 framework::dataset::make("HasBias", { true, true, true, false })),
83 framework::dataset::make("Expected", { false, false, false, true })),
Pablo Tello4a626a72018-04-04 10:01:14 +010084 input_info, output_info, has_bias, expected)
85{
86
87 bool status = bool(CLIm2Col::validate(&input_info, &output_info, Size2D(3U, 3U), PadStrideInfo(), has_bias));
88 ARM_COMPUTE_EXPECT(status == expected, framework::LogLevel::ERRORS);
89}
90// clang-format on
91// *INDENT-ON*
92
93template <typename T>
Gian Marco Iodice597a8562018-08-01 15:06:06 +010094using CLIm2ColFixture = Im2ColValidationFixture<CLTensor, CLAccessor, CLIm2Col, T, true>;
Pablo Tello4a626a72018-04-04 10:01:14 +010095TEST_SUITE(Float)
96TEST_SUITE(FP32)
Giorgio Arenafb629082018-08-20 18:03:27 +010097FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<float>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)),
98 conv_args))
Pablo Tello4a626a72018-04-04 10:01:14 +010099{
100 // Validate output
101 validate(CLAccessor(_target), _reference);
102}
Pablo Tello4a626a72018-04-04 10:01:14 +0100103
Giorgio Arenafb629082018-08-20 18:03:27 +0100104FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)),
105 conv_args))
Pablo Tello4a626a72018-04-04 10:01:14 +0100106{
107 // Validate output
108 validate(CLAccessor(_target), _reference);
109}
Giorgio Arena0f170392018-07-18 16:13:12 +0100110TEST_SUITE_END()
Pablo Tello4a626a72018-04-04 10:01:14 +0100111
112TEST_SUITE(FP16)
Giorgio Arenafb629082018-08-20 18:03:27 +0100113FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<half>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)),
114 conv_args))
Pablo Tello4a626a72018-04-04 10:01:14 +0100115{
116 // Validate output
117 validate(CLAccessor(_target), _reference);
118}
Giorgio Arenafb629082018-08-20 18:03:27 +0100119FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)),
120 conv_args))
Pablo Tello4a626a72018-04-04 10:01:14 +0100121{
122 // Validate output
123 validate(CLAccessor(_target), _reference);
124}
125TEST_SUITE_END()
Pablo Tello4a626a72018-04-04 10:01:14 +0100126TEST_SUITE_END()
127
128TEST_SUITE(QASYMM8)
Giorgio Arenafb629082018-08-20 18:03:27 +0100129FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8)),
130 conv_args))
Pablo Tello4a626a72018-04-04 10:01:14 +0100131{
132 // Validate output
133 validate(CLAccessor(_target), _reference);
134}
Giorgio Arenafb629082018-08-20 18:03:27 +0100135FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::QASYMM8)),
136 conv_args))
Pablo Tello4a626a72018-04-04 10:01:14 +0100137{
138 // Validate output
139 validate(CLAccessor(_target), _reference);
140}
141TEST_SUITE_END()
142
Giorgio Arena0f170392018-07-18 16:13:12 +0100143TEST_SUITE(Grouped)
144TEST_SUITE(FP32)
Giorgio Arenafb629082018-08-20 18:03:27 +0100145FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<float>, framework::DatasetMode::ALL, combine(combine(datasets::GroupedIm2ColSmallShapes(), framework::dataset::make("DataType",
146 DataType::F32)),
147 grouped_args))
Giorgio Arena0f170392018-07-18 16:13:12 +0100148{
149 // Validate output
150 validate(CLAccessor(_target), _reference);
151}
152
Giorgio Arenafb629082018-08-20 18:03:27 +0100153FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::GroupedIm2ColLargeShapes(), framework::dataset::make("DataType",
154 DataType::F32)),
155 grouped_args))
Giorgio Arena0f170392018-07-18 16:13:12 +0100156{
157 // Validate output
158 validate(CLAccessor(_target), _reference);
159}
160TEST_SUITE_END()
161
162TEST_SUITE(FP16)
Giorgio Arenafb629082018-08-20 18:03:27 +0100163FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<half>, framework::DatasetMode::ALL, combine(combine(datasets::GroupedIm2ColSmallShapes(), framework::dataset::make("DataType",
164 DataType::F16)),
165 grouped_args))
Giorgio Arena0f170392018-07-18 16:13:12 +0100166{
167 // Validate output
168 validate(CLAccessor(_target), _reference);
169}
170
Giorgio Arenafb629082018-08-20 18:03:27 +0100171FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::GroupedIm2ColLargeShapes(), framework::dataset::make("DataType",
172 DataType::F16)),
173 grouped_args))
Giorgio Arena0f170392018-07-18 16:13:12 +0100174{
175 // Validate output
176 validate(CLAccessor(_target), _reference);
177}
178TEST_SUITE_END()
179
180TEST_SUITE(QASYMM8)
Giorgio Arenafb629082018-08-20 18:03:27 +0100181FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(datasets::GroupedIm2ColSmallShapes(), framework::dataset::make("DataType",
182 DataType::QASYMM8)),
183 grouped_args))
Giorgio Arena0f170392018-07-18 16:13:12 +0100184{
185 // Validate output
186 validate(CLAccessor(_target), _reference);
187}
188
Giorgio Arenafb629082018-08-20 18:03:27 +0100189FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::GroupedIm2ColLargeShapes(), framework::dataset::make("DataType",
190 DataType::QASYMM8)),
191 grouped_args))
Giorgio Arena0f170392018-07-18 16:13:12 +0100192{
193 // Validate output
194 validate(CLAccessor(_target), _reference);
195}
196TEST_SUITE_END()
197TEST_SUITE_END()
198
Pablo Tello4a626a72018-04-04 10:01:14 +0100199TEST_SUITE_END()
200TEST_SUITE_END()
201} // namespace validation
202} // namespace test
203} // namespace arm_compute