blob: 31ad8c17176d2e6f848739d24bdff7b97eb5ef52 [file] [log] [blame]
Michalis Spyrouceb889e2018-09-17 18:24:41 +01001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/Types.h"
25#include "arm_compute/runtime/CL/CLTensor.h"
26#include "arm_compute/runtime/CL/CLTensorAllocator.h"
27#include "arm_compute/runtime/CL/functions/CLUpsampleLayer.h"
28#include "tests/CL/CLAccessor.h"
29#include "tests/PaddingCalculator.h"
30#include "tests/datasets/ShapeDatasets.h"
31#include "tests/framework/Asserts.h"
32#include "tests/framework/Macros.h"
33#include "tests/framework/datasets/Datasets.h"
34#include "tests/validation/Validation.h"
35#include "tests/validation/fixtures/UpsampleLayerFixture.h"
36
37namespace arm_compute
38{
39namespace test
40{
41namespace validation
42{
43namespace
44{
45constexpr AbsoluteTolerance<float> tolerance(0.001f);
46} // namespace
47
48TEST_SUITE(CL)
49TEST_SUITE(UpsampleLayer)
50
51DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, (combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32))),
52 input_shape, data_type)
53{
54 InterpolationPolicy upsampling_policy = InterpolationPolicy::NEAREST_NEIGHBOR;
55 Size2D info = Size2D(2, 2);
56
57 // Create tensors
58 CLTensor src = create_tensor<CLTensor>(input_shape, data_type, 1);
59 CLTensor dst;
60
61 ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
62 ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
63
64 // Create and configure function
65 CLUpsampleLayer upsample;
66 upsample.configure(&src, &dst, info, upsampling_policy);
67
68 // Validate valid region
69 const ValidRegion src_valid_region = shape_to_valid_region(src.info()->tensor_shape());
70 const ValidRegion dst_valid_region = shape_to_valid_region(dst.info()->tensor_shape());
71
72 validate(src.info()->valid_region(), src_valid_region);
73 validate(dst.info()->valid_region(), dst_valid_region);
74}
75
76// *INDENT-OFF*
77// clang-format off
78DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
79 framework::dataset::make("InputInfo", { TensorInfo(TensorShape(10U, 10U, 2U), 1, DataType::F32), // Mismatching data type
80 TensorInfo(TensorShape(10U, 10U, 2U), 1, DataType::F32), // Invalid output shape
81 TensorInfo(TensorShape(10U, 10U, 2U), 1, DataType::F32), // Invalid stride
82 TensorInfo(TensorShape(10U, 10U, 2U), 1, DataType::F32), // Invalid policy
83 TensorInfo(TensorShape(10U, 10U, 2U), 1, DataType::F32),
84 }),
85 framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(20U, 20U, 2U), 1, DataType::F16),
86 TensorInfo(TensorShape(20U, 10U, 2U), 1, DataType::F32),
87 TensorInfo(TensorShape(20U, 20U, 2U), 1, DataType::F32),
88 TensorInfo(TensorShape(20U, 20U, 2U), 1, DataType::F32),
89 TensorInfo(TensorShape(20U, 20U, 2U), 1, DataType::F32),
90 })),
91 framework::dataset::make("PadInfo", { Size2D(2, 2),
92 Size2D(2, 2),
93 Size2D(1, 1),
94 Size2D(2, 2),
95 Size2D(2, 2),
96 })),
97 framework::dataset::make("UpsamplingPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR,
98 InterpolationPolicy::NEAREST_NEIGHBOR,
99 InterpolationPolicy::NEAREST_NEIGHBOR,
100 InterpolationPolicy::BILINEAR,
101 InterpolationPolicy::NEAREST_NEIGHBOR,
102 })),
103 framework::dataset::make("Expected", { false, false, false, false, true })),
104 input_info, output_info, pad_info, upsampling_policy, expected)
105{
106 bool is_valid = bool(CLUpsampleLayer::validate(&input_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), pad_info, upsampling_policy));
107 ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS);
108}
109// clang-format on
110// *INDENT-ON*
111
112template <typename T>
113using CLUpsampleLayerFixture = UpsampleLayerFixture<CLTensor, CLAccessor, CLUpsampleLayer, T>;
114
115TEST_SUITE(Float)
116TEST_SUITE(FP32)
117FIXTURE_DATA_TEST_CASE(RunSmall, CLUpsampleLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
118 framework::dataset::make("DataType", DataType::F32)),
119 framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
120 framework::dataset::make("PadInfo", { Size2D(2, 2) })),
121 framework::dataset::make("UpsamplingPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR })))
122{
123 // Validate output
124 validate(CLAccessor(_target), _reference, tolerance);
125}
126TEST_SUITE_END() // FP32
127
128TEST_SUITE(FP16)
129
130FIXTURE_DATA_TEST_CASE(RunSmall, CLUpsampleLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
131 framework::dataset::make("DataType",
132 DataType::F16)),
133 framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
134 framework::dataset::make("PadInfo", { Size2D(2, 2) })),
135 framework::dataset::make("UpsamplingPolicy", { InterpolationPolicy::NEAREST_NEIGHBOR })))
136{
137 // Validate output
138 validate(CLAccessor(_target), _reference, tolerance);
139}
140
141TEST_SUITE_END() // FP16
142TEST_SUITE_END() // Float
143
144TEST_SUITE_END() // UpsampleLayer
145TEST_SUITE_END() // CL
146} // namespace validation
147} // namespace test
148} // namespace arm_compute