blob: e1908abe2f6ee7e02a9865e8db9e142b9c378b53 [file] [log] [blame]
Michalis Spyrou5237e012018-01-17 09:40:27 +00001/*
shubhame1a4e372019-01-07 21:37:55 +05302 * Copyright (c) 2018-2019 ARM Limited.
Michalis Spyrou5237e012018-01-17 09:40:27 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/Types.h"
25#include "arm_compute/runtime/CL/CLTensor.h"
26#include "arm_compute/runtime/CL/CLTensorAllocator.h"
27#include "arm_compute/runtime/CL/functions/CLPermute.h"
28#include "tests/CL/CLAccessor.h"
29#include "tests/PaddingCalculator.h"
30#include "tests/datasets/ShapeDatasets.h"
31#include "tests/framework/Asserts.h"
32#include "tests/framework/Macros.h"
33#include "tests/framework/datasets/Datasets.h"
34#include "tests/validation/Validation.h"
35#include "tests/validation/fixtures/PermuteFixture.h"
36
37namespace arm_compute
38{
39namespace test
40{
41namespace validation
42{
43namespace
44{
shubhame1a4e372019-01-07 21:37:55 +053045const auto PermuteVectors3 = framework::dataset::make("PermutationVector",
Pablo Tello35767bc2018-12-05 17:36:30 +000046{
47 PermutationVector(2U, 0U, 1U),
48 PermutationVector(1U, 2U, 0U),
shubhame1a4e372019-01-07 21:37:55 +053049 PermutationVector(0U, 1U, 2U),
50 PermutationVector(0U, 2U, 1U),
51 PermutationVector(1U, 0U, 2U),
52 PermutationVector(2U, 1U, 0U),
Pablo Tello35767bc2018-12-05 17:36:30 +000053});
shubhame1a4e372019-01-07 21:37:55 +053054const auto PermuteVectors4 = framework::dataset::make("PermutationVector",
55{
56 PermutationVector(3U, 2U, 0U, 1U),
57 PermutationVector(3U, 2U, 1U, 0U),
58 PermutationVector(2U, 3U, 1U, 0U),
59 PermutationVector(1U, 3U, 2U, 0U),
60 PermutationVector(3U, 1U, 2U, 0U),
61 PermutationVector(3U, 0U, 2U, 1U),
62 PermutationVector(0U, 3U, 2U, 1U)
63});
64const auto PermuteVectors = concat(PermuteVectors3, PermuteVectors4);
Pablo Tello35767bc2018-12-05 17:36:30 +000065const auto PermuteInputLayout = framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC });
66const auto PermuteParametersSmall = concat(concat(datasets::Small2DShapes(), datasets::Small3DShapes()), datasets::Small4DShapes()) * PermuteInputLayout * PermuteVectors;
67const auto PermuteParametersLarge = datasets::Large4DShapes() * PermuteInputLayout * PermuteVectors;
Michalis Spyrou5237e012018-01-17 09:40:27 +000068} // namespace
69TEST_SUITE(CL)
70TEST_SUITE(Permute)
71
Isabella Gottardiaad9f2c2018-02-21 11:51:23 +000072// *INDENT-OFF*
73// clang-format off
74DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
shubhame1a4e372019-01-07 21:37:55 +053075 framework::dataset::make("InputInfo",{
76 TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16), // valid
77 TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16), // permutation not supported
78 TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16), // permutation not supported
79 TensorInfo(TensorShape(1U, 7U), 1, DataType::U8), // invalid input size
80 TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16), // valid
81 TensorInfo(TensorShape(27U, 13U, 37U, 2U), 1, DataType::F32), // valid
82 TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16), // permutation not supported
83 TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::S16), // valid
84 TensorInfo(TensorShape(27U, 13U, 37U, 2U), 1, DataType::F32), // permutation not supported
85 TensorInfo(TensorShape(27U, 13U, 37U, 2U), 1, DataType::F32), // valid
86 TensorInfo(TensorShape(27U, 13U, 37U, 2U), 1, DataType::F32) // permutation not supported
87
88 }),
89 framework::dataset::make("OutputInfo", {
90 TensorInfo(TensorShape(5U, 7U, 7U, 3U), 1, DataType::U16),
91 TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16),
92 TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16),
93 TensorInfo(TensorShape(5U, 7U), 1, DataType::U8),
94 TensorInfo(TensorShape(5U, 7U, 7U, 3U), 1, DataType::U16),
95 TensorInfo(TensorShape(13U, 37U, 27U, 2U), 1, DataType::F32),
96 TensorInfo(TensorShape(5U, 7U, 7U, 3U), 1, DataType::U16),
97 TensorInfo(TensorShape(3U, 5U, 7U, 7U), 1, DataType::S16),
98 TensorInfo(TensorShape(13U, 37U, 27U, 2U), 1, DataType::F32),
99 TensorInfo(TensorShape(37U, 2U, 13U, 27U), 1, DataType::F32),
100 TensorInfo(TensorShape(37U, 2U, 13U, 27U), 1, DataType::F32)
101
102 })),
103 framework::dataset::make("PermutationVector", {
104 PermutationVector(2U, 1U, 0U),
105 PermutationVector(2U, 2U, 1U),
106 PermutationVector(1U, 1U, 1U),
107 PermutationVector(2U, 0U, 1U),
108 PermutationVector(2U, 0U, 1U),
109 PermutationVector(1U, 2U, 0U),
110 PermutationVector(3U, 2U, 0U, 1U),
111 PermutationVector(3U, 2U, 0U, 1U),
112 PermutationVector(2U, 3U, 1U, 0U),
113 PermutationVector(2U, 3U, 1U, 0U),
114 PermutationVector(0U, 0U, 0U, 1000U)
115 })),
116 framework::dataset::make("Expected", { true, false, false, false, true, true, false, true, false, true, false })),
117 input_info, output_info, perm_vect, expected)
Isabella Gottardiaad9f2c2018-02-21 11:51:23 +0000118{
119 ARM_COMPUTE_EXPECT(bool(CLPermute::validate(&input_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), perm_vect)) == expected, framework::LogLevel::ERRORS);
120}
121// clang-format on
122// *INDENT-ON*
123
Michalis Spyrou5237e012018-01-17 09:40:27 +0000124DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::Small4DShapes(), framework::dataset::make("DataType", { DataType::S8, DataType::U8, DataType::S16, DataType::U16, DataType::U32, DataType::S32, DataType::F16, DataType::F32 })),
125 shape, data_type)
126{
127 // Define permutation vector
128 const PermutationVector perm(2U, 0U, 1U);
129
130 // Permute shapes
131 TensorShape output_shape = shape;
132 permute(output_shape, perm);
133
134 // Create tensors
135 CLTensor ref_src = create_tensor<CLTensor>(shape, data_type);
136 CLTensor dst = create_tensor<CLTensor>(output_shape, data_type);
137
138 // Create and Configure function
139 CLPermute perm_func;
140 perm_func.configure(&ref_src, &dst, perm);
141
142 // Validate valid region
143 const ValidRegion valid_region = shape_to_valid_region(output_shape);
144 validate(dst.info()->valid_region(), valid_region);
145}
146
shubhame1a4e372019-01-07 21:37:55 +0530147#ifndef DOXYGEN_SKIP_THIS
148
Michalis Spyrou5237e012018-01-17 09:40:27 +0000149template <typename T>
150using CLPermuteFixture = PermuteValidationFixture<CLTensor, CLAccessor, CLPermute, T>;
151
152TEST_SUITE(U8)
shubhame1a4e372019-01-07 21:37:55 +0530153FIXTURE_DATA_TEST_CASE(RunSmall, CLPermuteFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
154 PermuteParametersSmall * framework::dataset::make("DataType", DataType::U8))
Michalis Spyrou5237e012018-01-17 09:40:27 +0000155{
156 // Validate output
157 validate(CLAccessor(_target), _reference);
158}
shubhame1a4e372019-01-07 21:37:55 +0530159
160FIXTURE_DATA_TEST_CASE(RunLarge, CLPermuteFixture<uint8_t>, framework::DatasetMode::NIGHTLY,
161 PermuteParametersLarge * framework::dataset::make("DataType", DataType::U8))
Michalis Spyrou5237e012018-01-17 09:40:27 +0000162{
163 // Validate output
164 validate(CLAccessor(_target), _reference);
165}
shubhame1a4e372019-01-07 21:37:55 +0530166TEST_SUITE_END() // U8
Michalis Spyrou5237e012018-01-17 09:40:27 +0000167
168TEST_SUITE(U16)
shubhame1a4e372019-01-07 21:37:55 +0530169FIXTURE_DATA_TEST_CASE(RunSmall, CLPermuteFixture<uint16_t>, framework::DatasetMode::PRECOMMIT,
170 PermuteParametersSmall * framework::dataset::make("DataType", DataType::U16))
Michalis Spyrou5237e012018-01-17 09:40:27 +0000171{
172 // Validate output
173 validate(CLAccessor(_target), _reference);
174}
shubhame1a4e372019-01-07 21:37:55 +0530175FIXTURE_DATA_TEST_CASE(RunLarge, CLPermuteFixture<uint16_t>, framework::DatasetMode::NIGHTLY,
176 PermuteParametersLarge * framework::dataset::make("DataType", DataType::U16))
Michalis Spyrou5237e012018-01-17 09:40:27 +0000177{
178 // Validate output
179 validate(CLAccessor(_target), _reference);
180}
shubhame1a4e372019-01-07 21:37:55 +0530181TEST_SUITE_END() // U16
Michalis Spyrou5237e012018-01-17 09:40:27 +0000182
183TEST_SUITE(U32)
shubhame1a4e372019-01-07 21:37:55 +0530184FIXTURE_DATA_TEST_CASE(RunSmall, CLPermuteFixture<uint32_t>, framework::DatasetMode::PRECOMMIT,
185 PermuteParametersSmall * framework::dataset::make("DataType", DataType::U32))
Michalis Spyrou5237e012018-01-17 09:40:27 +0000186{
187 // Validate output
188 validate(CLAccessor(_target), _reference);
189}
shubhame1a4e372019-01-07 21:37:55 +0530190FIXTURE_DATA_TEST_CASE(RunLarge, CLPermuteFixture<uint32_t>, framework::DatasetMode::NIGHTLY,
191 PermuteParametersLarge * framework::dataset::make("DataType", DataType::U32))
Michalis Spyrou5237e012018-01-17 09:40:27 +0000192{
193 // Validate output
194 validate(CLAccessor(_target), _reference);
195}
shubhame1a4e372019-01-07 21:37:55 +0530196TEST_SUITE_END() // U32
Michalis Spyrou5237e012018-01-17 09:40:27 +0000197
shubhame1a4e372019-01-07 21:37:55 +0530198#endif /* DOXYGEN_SKIP_THIS */
199
200TEST_SUITE_END() // Permute
201TEST_SUITE_END() // CL
Michalis Spyrou5237e012018-01-17 09:40:27 +0000202} // namespace validation
203} // namespace test
204} // namespace arm_compute