blob: d233e89be8d0f9ae33cbb305ba4f5c1512c8c9e2 [file] [log] [blame]
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01001//
Teresa Charlinfbf0e5b2020-08-17 01:01:06 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01003// SPDX-License-Identifier: MIT
4//
5
6#include "ReshapeTestImpl.hpp"
7
8#include <backendsCommon/test/DataTypeUtils.hpp>
9#include <backendsCommon/test/TensorCopyUtils.hpp>
10#include <backendsCommon/test/WorkloadTestUtils.hpp>
11
12#include <test/TensorHelpers.hpp>
13
14namespace
15{
16
17template<typename T, size_t NumDims>
18LayerTestResult<T, NumDims> SimpleReshapeTestImpl(
19 armnn::IWorkloadFactory& workloadFactory,
20 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsec36d3e2020-08-28 13:17:05 +010021 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010022 armnn::TensorInfo inputTensorInfo,
23 armnn::TensorInfo outputTensorInfo,
24 const std::vector<T>& inputData,
25 const std::vector<T>& outputExpectedData)
26{
Jan Eilers8eb25602020-03-09 12:13:48 +000027 IgnoreUnused(memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010028 auto input = MakeTensor<T, NumDims>(inputTensorInfo, inputData);
29
30 LayerTestResult<T, NumDims> ret(outputTensorInfo);
31 ret.outputExpected = MakeTensor<T, NumDims>(outputTensorInfo, outputExpectedData);
32
Finn Williamsec36d3e2020-08-28 13:17:05 +010033 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
34 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010035
36 armnn::ReshapeQueueDescriptor data;
37 armnn::WorkloadInfo info;
38 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
39 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
40
41 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateReshape(data, info);
42
43 inputHandle->Allocate();
44 outputHandle->Allocate();
45
46 CopyDataToITensorHandle(inputHandle.get(), input.origin());
47
48 workload->Execute();
49
50 CopyDataFromITensorHandle(ret.output.origin(), outputHandle.get());
51
52 return ret;
53}
54
55} // anonymous namespace
56
57template<armnn::DataType ArmnnType, typename T>
58LayerTestResult<T, 4> SimpleReshapeTest(
59 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsec36d3e2020-08-28 13:17:05 +010060 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
61 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010062{
63 armnn::TensorInfo inputTensorInfo;
64 armnn::TensorInfo outputTensorInfo;
65
66 unsigned int inputShape[] = { 2, 2, 3, 3 };
67 unsigned int outputShape[] = { 2, 2, 9, 1 };
68
69 inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType);
70 inputTensorInfo.SetQuantizationScale(1.0f);
71 outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
72 outputTensorInfo.SetQuantizationScale(1.0f);
73
74 auto input = ConvertToDataType<ArmnnType>(
75 {
76 0.0f, 1.0f, 2.0f,
77 3.0f, 4.0f, 5.0f,
78 6.0f, 7.0f, 8.0f,
79
80 9.0f, 10.0f, 11.0f,
81 12.0f, 13.0f, 14.0f,
82 15.0f, 16.0f, 17.0f,
83
84 18.0f, 19.0f, 20.0f,
85 21.0f, 22.0f, 23.0f,
86 24.0f, 25.0f, 26.0f,
87
88 27.0f, 28.0f, 29.0f,
89 30.0f, 31.0f, 32.0f,
90 33.0f, 34.0f, 35.0f,
91 },
92 inputTensorInfo);
93
94 auto outputExpected = ConvertToDataType<ArmnnType>(
95 {
96 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
97
98 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f,
99
100 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f,
101
102 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
103 },
104 outputTensorInfo);
105
106 return SimpleReshapeTestImpl<T, 4>(
Finn Williamsec36d3e2020-08-28 13:17:05 +0100107 workloadFactory, memoryManager, tensorHandleFactory, inputTensorInfo, outputTensorInfo, input, outputExpected);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100108}
109
110template<armnn::DataType ArmnnType, typename T>
111LayerTestResult<T, 5> Reshape5dTest(
112 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsec36d3e2020-08-28 13:17:05 +0100113 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
114 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100115{
116 armnn::TensorInfo inputTensorInfo;
117 armnn::TensorInfo outputTensorInfo;
118
119 unsigned int inputShape[] = { 2, 2, 8, 1, 1 };
120 unsigned int outputShape[] = { 2, 2, 2, 2, 2 };
121
122 inputTensorInfo = armnn::TensorInfo(5, inputShape, ArmnnType);
123 inputTensorInfo.SetQuantizationScale(1.0f);
124 outputTensorInfo = armnn::TensorInfo(5, outputShape, ArmnnType);
125 outputTensorInfo.SetQuantizationScale(1.0f);
126
127 auto input = ConvertToDataType<ArmnnType>(
128 {
129 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f,
130 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
131
132 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f,
133 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f,
134 },
135 inputTensorInfo);
136
137 auto outputExpected = ConvertToDataType<ArmnnType>(
138 {
139 0.0f, 1.0f,
140 2.0f, 3.0f,
141
142 4.0f, 5.0f,
143 6.0f, 7.0f,
144
145
146 8.0f, 9.0f,
147 10.0f, 11.0f,
148
149 12.0f, 13.0f,
150 14.0f, 15.0f,
151
152
153
154 16.0f, 17.0f,
155 18.0f, 19.0f,
156
157 20.0f, 21.0f,
158 22.0f, 23.0f,
159
160
161 24.0f, 25.0f,
162 26.0f, 27.0f,
163
164 28.0f, 29.0f,
165 30.0f, 31.0f,
166 },
167 outputTensorInfo);
168
169 return SimpleReshapeTestImpl<T, 5>(
Finn Williamsec36d3e2020-08-28 13:17:05 +0100170 workloadFactory, memoryManager, tensorHandleFactory, inputTensorInfo, outputTensorInfo, input, outputExpected);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100171}
172
173//
174// Explicit template specializations
175//
176
177template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
178SimpleReshapeTest<armnn::DataType::Float32>(
179 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsec36d3e2020-08-28 13:17:05 +0100180 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
181 const armnn::ITensorHandleFactory& tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100182
Sadik Armagan303980c2020-04-17 12:45:14 +0100183template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 4>
184SimpleReshapeTest<armnn::DataType::QAsymmS8>(
185 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsec36d3e2020-08-28 13:17:05 +0100186 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
187 const armnn::ITensorHandleFactory& tensorHandleFactory);
Sadik Armagan303980c2020-04-17 12:45:14 +0100188
Derek Lambertif90c56d2020-01-10 17:14:08 +0000189template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 4>
190SimpleReshapeTest<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100191 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsec36d3e2020-08-28 13:17:05 +0100192 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
193 const armnn::ITensorHandleFactory& tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100194
Derek Lambertif90c56d2020-01-10 17:14:08 +0000195template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 4>
196SimpleReshapeTest<armnn::DataType::QSymmS16>(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100197 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsec36d3e2020-08-28 13:17:05 +0100198 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
199 const armnn::ITensorHandleFactory& tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100200
201template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 5>
202Reshape5dTest<armnn::DataType::Float32>(
203 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsec36d3e2020-08-28 13:17:05 +0100204 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
205 const armnn::ITensorHandleFactory& tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100206
Sadik Armagan303980c2020-04-17 12:45:14 +0100207template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 5>
208Reshape5dTest<armnn::DataType::QAsymmS8>(
209 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsec36d3e2020-08-28 13:17:05 +0100210 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
211 const armnn::ITensorHandleFactory& tensorHandleFactory);
Sadik Armagan303980c2020-04-17 12:45:14 +0100212
Derek Lambertif90c56d2020-01-10 17:14:08 +0000213template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 5>
214Reshape5dTest<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100215 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsec36d3e2020-08-28 13:17:05 +0100216 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
217 const armnn::ITensorHandleFactory& tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100218
Derek Lambertif90c56d2020-01-10 17:14:08 +0000219template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 5>
220Reshape5dTest<armnn::DataType::QSymmS16>(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100221 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsec36d3e2020-08-28 13:17:05 +0100222 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
223 const armnn::ITensorHandleFactory& tensorHandleFactory);