blob: 1e528a0fdcae2bba4e300bdd938c58006512da5c [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "ReshapeTestImpl.hpp"
7
8#include <backendsCommon/test/DataTypeUtils.hpp>
9#include <backendsCommon/test/TensorCopyUtils.hpp>
10#include <backendsCommon/test/WorkloadTestUtils.hpp>
11
12#include <test/TensorHelpers.hpp>
13
14namespace
15{
16
17template<typename T, size_t NumDims>
18LayerTestResult<T, NumDims> SimpleReshapeTestImpl(
19 armnn::IWorkloadFactory& workloadFactory,
20 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
21 armnn::TensorInfo inputTensorInfo,
22 armnn::TensorInfo outputTensorInfo,
23 const std::vector<T>& inputData,
24 const std::vector<T>& outputExpectedData)
25{
Jan Eilers8eb25602020-03-09 12:13:48 +000026 IgnoreUnused(memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010027 auto input = MakeTensor<T, NumDims>(inputTensorInfo, inputData);
28
29 LayerTestResult<T, NumDims> ret(outputTensorInfo);
30 ret.outputExpected = MakeTensor<T, NumDims>(outputTensorInfo, outputExpectedData);
31
Teresa Charlinfbf0e5b2020-08-17 01:01:06 +010032 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010033 std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
34 std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
Teresa Charlinfbf0e5b2020-08-17 01:01:06 +010035 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010036
37 armnn::ReshapeQueueDescriptor data;
38 armnn::WorkloadInfo info;
39 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
40 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
41
42 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateReshape(data, info);
43
44 inputHandle->Allocate();
45 outputHandle->Allocate();
46
47 CopyDataToITensorHandle(inputHandle.get(), input.origin());
48
49 workload->Execute();
50
51 CopyDataFromITensorHandle(ret.output.origin(), outputHandle.get());
52
53 return ret;
54}
55
56} // anonymous namespace
57
58template<armnn::DataType ArmnnType, typename T>
59LayerTestResult<T, 4> SimpleReshapeTest(
60 armnn::IWorkloadFactory& workloadFactory,
61 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
62{
63 armnn::TensorInfo inputTensorInfo;
64 armnn::TensorInfo outputTensorInfo;
65
66 unsigned int inputShape[] = { 2, 2, 3, 3 };
67 unsigned int outputShape[] = { 2, 2, 9, 1 };
68
69 inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType);
70 inputTensorInfo.SetQuantizationScale(1.0f);
71 outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
72 outputTensorInfo.SetQuantizationScale(1.0f);
73
74 auto input = ConvertToDataType<ArmnnType>(
75 {
76 0.0f, 1.0f, 2.0f,
77 3.0f, 4.0f, 5.0f,
78 6.0f, 7.0f, 8.0f,
79
80 9.0f, 10.0f, 11.0f,
81 12.0f, 13.0f, 14.0f,
82 15.0f, 16.0f, 17.0f,
83
84 18.0f, 19.0f, 20.0f,
85 21.0f, 22.0f, 23.0f,
86 24.0f, 25.0f, 26.0f,
87
88 27.0f, 28.0f, 29.0f,
89 30.0f, 31.0f, 32.0f,
90 33.0f, 34.0f, 35.0f,
91 },
92 inputTensorInfo);
93
94 auto outputExpected = ConvertToDataType<ArmnnType>(
95 {
96 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
97
98 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f,
99
100 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f,
101
102 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
103 },
104 outputTensorInfo);
105
106 return SimpleReshapeTestImpl<T, 4>(
107 workloadFactory, memoryManager, inputTensorInfo, outputTensorInfo, input, outputExpected);
108}
109
110template<armnn::DataType ArmnnType, typename T>
111LayerTestResult<T, 5> Reshape5dTest(
112 armnn::IWorkloadFactory& workloadFactory,
113 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
114{
115 armnn::TensorInfo inputTensorInfo;
116 armnn::TensorInfo outputTensorInfo;
117
118 unsigned int inputShape[] = { 2, 2, 8, 1, 1 };
119 unsigned int outputShape[] = { 2, 2, 2, 2, 2 };
120
121 inputTensorInfo = armnn::TensorInfo(5, inputShape, ArmnnType);
122 inputTensorInfo.SetQuantizationScale(1.0f);
123 outputTensorInfo = armnn::TensorInfo(5, outputShape, ArmnnType);
124 outputTensorInfo.SetQuantizationScale(1.0f);
125
126 auto input = ConvertToDataType<ArmnnType>(
127 {
128 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f,
129 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
130
131 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f,
132 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f,
133 },
134 inputTensorInfo);
135
136 auto outputExpected = ConvertToDataType<ArmnnType>(
137 {
138 0.0f, 1.0f,
139 2.0f, 3.0f,
140
141 4.0f, 5.0f,
142 6.0f, 7.0f,
143
144
145 8.0f, 9.0f,
146 10.0f, 11.0f,
147
148 12.0f, 13.0f,
149 14.0f, 15.0f,
150
151
152
153 16.0f, 17.0f,
154 18.0f, 19.0f,
155
156 20.0f, 21.0f,
157 22.0f, 23.0f,
158
159
160 24.0f, 25.0f,
161 26.0f, 27.0f,
162
163 28.0f, 29.0f,
164 30.0f, 31.0f,
165 },
166 outputTensorInfo);
167
168 return SimpleReshapeTestImpl<T, 5>(
169 workloadFactory, memoryManager, inputTensorInfo, outputTensorInfo, input, outputExpected);
170}
171
172//
173// Explicit template specializations
174//
175
176template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
177SimpleReshapeTest<armnn::DataType::Float32>(
178 armnn::IWorkloadFactory& workloadFactory,
179 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
180
Sadik Armagan303980c2020-04-17 12:45:14 +0100181template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 4>
182SimpleReshapeTest<armnn::DataType::QAsymmS8>(
183 armnn::IWorkloadFactory& workloadFactory,
184 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
185
Derek Lambertif90c56d2020-01-10 17:14:08 +0000186template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 4>
187SimpleReshapeTest<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100188 armnn::IWorkloadFactory& workloadFactory,
189 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
190
Derek Lambertif90c56d2020-01-10 17:14:08 +0000191template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 4>
192SimpleReshapeTest<armnn::DataType::QSymmS16>(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100193 armnn::IWorkloadFactory& workloadFactory,
194 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
195
196template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 5>
197Reshape5dTest<armnn::DataType::Float32>(
198 armnn::IWorkloadFactory& workloadFactory,
199 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
200
Sadik Armagan303980c2020-04-17 12:45:14 +0100201template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 5>
202Reshape5dTest<armnn::DataType::QAsymmS8>(
203 armnn::IWorkloadFactory& workloadFactory,
204 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
205
Derek Lambertif90c56d2020-01-10 17:14:08 +0000206template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 5>
207Reshape5dTest<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100208 armnn::IWorkloadFactory& workloadFactory,
209 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
210
Derek Lambertif90c56d2020-01-10 17:14:08 +0000211template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 5>
212Reshape5dTest<armnn::DataType::QSymmS16>(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100213 armnn::IWorkloadFactory& workloadFactory,
214 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);