blob: 979d0a7f730c6757ed2632951a4f2bc63bce2bf2 [file] [log] [blame]
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ReshapeTestImpl.hpp"
7
8#include <backendsCommon/test/DataTypeUtils.hpp>
9#include <backendsCommon/test/TensorCopyUtils.hpp>
10#include <backendsCommon/test/WorkloadTestUtils.hpp>
11
12#include <test/TensorHelpers.hpp>
13
14namespace
15{
16
17template<typename T, size_t NumDims>
18LayerTestResult<T, NumDims> SimpleReshapeTestImpl(
19 armnn::IWorkloadFactory& workloadFactory,
20 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
21 armnn::TensorInfo inputTensorInfo,
22 armnn::TensorInfo outputTensorInfo,
23 const std::vector<T>& inputData,
24 const std::vector<T>& outputExpectedData)
25{
Jan Eilers8eb25602020-03-09 12:13:48 +000026 IgnoreUnused(memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010027 auto input = MakeTensor<T, NumDims>(inputTensorInfo, inputData);
28
29 LayerTestResult<T, NumDims> ret(outputTensorInfo);
30 ret.outputExpected = MakeTensor<T, NumDims>(outputTensorInfo, outputExpectedData);
31
32 std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
33 std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
34
35 armnn::ReshapeQueueDescriptor data;
36 armnn::WorkloadInfo info;
37 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
38 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
39
40 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateReshape(data, info);
41
42 inputHandle->Allocate();
43 outputHandle->Allocate();
44
45 CopyDataToITensorHandle(inputHandle.get(), input.origin());
46
47 workload->Execute();
48
49 CopyDataFromITensorHandle(ret.output.origin(), outputHandle.get());
50
51 return ret;
52}
53
54} // anonymous namespace
55
56template<armnn::DataType ArmnnType, typename T>
57LayerTestResult<T, 4> SimpleReshapeTest(
58 armnn::IWorkloadFactory& workloadFactory,
59 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
60{
61 armnn::TensorInfo inputTensorInfo;
62 armnn::TensorInfo outputTensorInfo;
63
64 unsigned int inputShape[] = { 2, 2, 3, 3 };
65 unsigned int outputShape[] = { 2, 2, 9, 1 };
66
67 inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType);
68 inputTensorInfo.SetQuantizationScale(1.0f);
69 outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
70 outputTensorInfo.SetQuantizationScale(1.0f);
71
72 auto input = ConvertToDataType<ArmnnType>(
73 {
74 0.0f, 1.0f, 2.0f,
75 3.0f, 4.0f, 5.0f,
76 6.0f, 7.0f, 8.0f,
77
78 9.0f, 10.0f, 11.0f,
79 12.0f, 13.0f, 14.0f,
80 15.0f, 16.0f, 17.0f,
81
82 18.0f, 19.0f, 20.0f,
83 21.0f, 22.0f, 23.0f,
84 24.0f, 25.0f, 26.0f,
85
86 27.0f, 28.0f, 29.0f,
87 30.0f, 31.0f, 32.0f,
88 33.0f, 34.0f, 35.0f,
89 },
90 inputTensorInfo);
91
92 auto outputExpected = ConvertToDataType<ArmnnType>(
93 {
94 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
95
96 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f,
97
98 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f,
99
100 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
101 },
102 outputTensorInfo);
103
104 return SimpleReshapeTestImpl<T, 4>(
105 workloadFactory, memoryManager, inputTensorInfo, outputTensorInfo, input, outputExpected);
106}
107
108template<armnn::DataType ArmnnType, typename T>
109LayerTestResult<T, 5> Reshape5dTest(
110 armnn::IWorkloadFactory& workloadFactory,
111 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
112{
113 armnn::TensorInfo inputTensorInfo;
114 armnn::TensorInfo outputTensorInfo;
115
116 unsigned int inputShape[] = { 2, 2, 8, 1, 1 };
117 unsigned int outputShape[] = { 2, 2, 2, 2, 2 };
118
119 inputTensorInfo = armnn::TensorInfo(5, inputShape, ArmnnType);
120 inputTensorInfo.SetQuantizationScale(1.0f);
121 outputTensorInfo = armnn::TensorInfo(5, outputShape, ArmnnType);
122 outputTensorInfo.SetQuantizationScale(1.0f);
123
124 auto input = ConvertToDataType<ArmnnType>(
125 {
126 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f,
127 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
128
129 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f,
130 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f,
131 },
132 inputTensorInfo);
133
134 auto outputExpected = ConvertToDataType<ArmnnType>(
135 {
136 0.0f, 1.0f,
137 2.0f, 3.0f,
138
139 4.0f, 5.0f,
140 6.0f, 7.0f,
141
142
143 8.0f, 9.0f,
144 10.0f, 11.0f,
145
146 12.0f, 13.0f,
147 14.0f, 15.0f,
148
149
150
151 16.0f, 17.0f,
152 18.0f, 19.0f,
153
154 20.0f, 21.0f,
155 22.0f, 23.0f,
156
157
158 24.0f, 25.0f,
159 26.0f, 27.0f,
160
161 28.0f, 29.0f,
162 30.0f, 31.0f,
163 },
164 outputTensorInfo);
165
166 return SimpleReshapeTestImpl<T, 5>(
167 workloadFactory, memoryManager, inputTensorInfo, outputTensorInfo, input, outputExpected);
168}
169
170//
// Explicit template instantiations
172//
173
174template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
175SimpleReshapeTest<armnn::DataType::Float32>(
176 armnn::IWorkloadFactory& workloadFactory,
177 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
178
Sadik Armagan303980c2020-04-17 12:45:14 +0100179template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 4>
180SimpleReshapeTest<armnn::DataType::QAsymmS8>(
181 armnn::IWorkloadFactory& workloadFactory,
182 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
183
Derek Lambertif90c56d2020-01-10 17:14:08 +0000184template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 4>
185SimpleReshapeTest<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100186 armnn::IWorkloadFactory& workloadFactory,
187 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
188
Derek Lambertif90c56d2020-01-10 17:14:08 +0000189template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 4>
190SimpleReshapeTest<armnn::DataType::QSymmS16>(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100191 armnn::IWorkloadFactory& workloadFactory,
192 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
193
194template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 5>
195Reshape5dTest<armnn::DataType::Float32>(
196 armnn::IWorkloadFactory& workloadFactory,
197 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
198
Sadik Armagan303980c2020-04-17 12:45:14 +0100199template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 5>
200Reshape5dTest<armnn::DataType::QAsymmS8>(
201 armnn::IWorkloadFactory& workloadFactory,
202 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
203
Derek Lambertif90c56d2020-01-10 17:14:08 +0000204template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 5>
205Reshape5dTest<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100206 armnn::IWorkloadFactory& workloadFactory,
207 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
208
Derek Lambertif90c56d2020-01-10 17:14:08 +0000209template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 5>
210Reshape5dTest<armnn::DataType::QSymmS16>(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100211 armnn::IWorkloadFactory& workloadFactory,
212 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);