blob: bce24f06bcb16552254524e4a7c0db3187eabc3a [file] [log] [blame]
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ReshapeTestImpl.hpp"
7
8#include <backendsCommon/test/DataTypeUtils.hpp>
9#include <backendsCommon/test/TensorCopyUtils.hpp>
10#include <backendsCommon/test/WorkloadTestUtils.hpp>
11
12#include <test/TensorHelpers.hpp>
13
14namespace
15{
16
17template<typename T, size_t NumDims>
18LayerTestResult<T, NumDims> SimpleReshapeTestImpl(
19 armnn::IWorkloadFactory& workloadFactory,
20 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
21 armnn::TensorInfo inputTensorInfo,
22 armnn::TensorInfo outputTensorInfo,
23 const std::vector<T>& inputData,
24 const std::vector<T>& outputExpectedData)
25{
26 auto input = MakeTensor<T, NumDims>(inputTensorInfo, inputData);
27
28 LayerTestResult<T, NumDims> ret(outputTensorInfo);
29 ret.outputExpected = MakeTensor<T, NumDims>(outputTensorInfo, outputExpectedData);
30
31 std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
32 std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
33
34 armnn::ReshapeQueueDescriptor data;
35 armnn::WorkloadInfo info;
36 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
37 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
38
39 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateReshape(data, info);
40
41 inputHandle->Allocate();
42 outputHandle->Allocate();
43
44 CopyDataToITensorHandle(inputHandle.get(), input.origin());
45
46 workload->Execute();
47
48 CopyDataFromITensorHandle(ret.output.origin(), outputHandle.get());
49
50 return ret;
51}
52
53} // anonymous namespace
54
55template<armnn::DataType ArmnnType, typename T>
56LayerTestResult<T, 4> SimpleReshapeTest(
57 armnn::IWorkloadFactory& workloadFactory,
58 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
59{
60 armnn::TensorInfo inputTensorInfo;
61 armnn::TensorInfo outputTensorInfo;
62
63 unsigned int inputShape[] = { 2, 2, 3, 3 };
64 unsigned int outputShape[] = { 2, 2, 9, 1 };
65
66 inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType);
67 inputTensorInfo.SetQuantizationScale(1.0f);
68 outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
69 outputTensorInfo.SetQuantizationScale(1.0f);
70
71 auto input = ConvertToDataType<ArmnnType>(
72 {
73 0.0f, 1.0f, 2.0f,
74 3.0f, 4.0f, 5.0f,
75 6.0f, 7.0f, 8.0f,
76
77 9.0f, 10.0f, 11.0f,
78 12.0f, 13.0f, 14.0f,
79 15.0f, 16.0f, 17.0f,
80
81 18.0f, 19.0f, 20.0f,
82 21.0f, 22.0f, 23.0f,
83 24.0f, 25.0f, 26.0f,
84
85 27.0f, 28.0f, 29.0f,
86 30.0f, 31.0f, 32.0f,
87 33.0f, 34.0f, 35.0f,
88 },
89 inputTensorInfo);
90
91 auto outputExpected = ConvertToDataType<ArmnnType>(
92 {
93 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
94
95 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f,
96
97 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f,
98
99 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
100 },
101 outputTensorInfo);
102
103 return SimpleReshapeTestImpl<T, 4>(
104 workloadFactory, memoryManager, inputTensorInfo, outputTensorInfo, input, outputExpected);
105}
106
107template<armnn::DataType ArmnnType, typename T>
108LayerTestResult<T, 5> Reshape5dTest(
109 armnn::IWorkloadFactory& workloadFactory,
110 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
111{
112 armnn::TensorInfo inputTensorInfo;
113 armnn::TensorInfo outputTensorInfo;
114
115 unsigned int inputShape[] = { 2, 2, 8, 1, 1 };
116 unsigned int outputShape[] = { 2, 2, 2, 2, 2 };
117
118 inputTensorInfo = armnn::TensorInfo(5, inputShape, ArmnnType);
119 inputTensorInfo.SetQuantizationScale(1.0f);
120 outputTensorInfo = armnn::TensorInfo(5, outputShape, ArmnnType);
121 outputTensorInfo.SetQuantizationScale(1.0f);
122
123 auto input = ConvertToDataType<ArmnnType>(
124 {
125 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f,
126 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
127
128 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f,
129 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f,
130 },
131 inputTensorInfo);
132
133 auto outputExpected = ConvertToDataType<ArmnnType>(
134 {
135 0.0f, 1.0f,
136 2.0f, 3.0f,
137
138 4.0f, 5.0f,
139 6.0f, 7.0f,
140
141
142 8.0f, 9.0f,
143 10.0f, 11.0f,
144
145 12.0f, 13.0f,
146 14.0f, 15.0f,
147
148
149
150 16.0f, 17.0f,
151 18.0f, 19.0f,
152
153 20.0f, 21.0f,
154 22.0f, 23.0f,
155
156
157 24.0f, 25.0f,
158 26.0f, 27.0f,
159
160 28.0f, 29.0f,
161 30.0f, 31.0f,
162 },
163 outputTensorInfo);
164
165 return SimpleReshapeTestImpl<T, 5>(
166 workloadFactory, memoryManager, inputTensorInfo, outputTensorInfo, input, outputExpected);
167}
168
169//
170// Explicit template specializations
171//
172
173template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
174SimpleReshapeTest<armnn::DataType::Float32>(
175 armnn::IWorkloadFactory& workloadFactory,
176 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
177
178template LayerTestResult<armnn::ResolveType<armnn::DataType::QuantisedAsymm8>, 4>
179SimpleReshapeTest<armnn::DataType::QuantisedAsymm8>(
180 armnn::IWorkloadFactory& workloadFactory,
181 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
182
183template LayerTestResult<armnn::ResolveType<armnn::DataType::QuantisedSymm16>, 4>
184SimpleReshapeTest<armnn::DataType::QuantisedSymm16>(
185 armnn::IWorkloadFactory& workloadFactory,
186 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
187
188template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 5>
189Reshape5dTest<armnn::DataType::Float32>(
190 armnn::IWorkloadFactory& workloadFactory,
191 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
192
193template LayerTestResult<armnn::ResolveType<armnn::DataType::QuantisedAsymm8>, 5>
194Reshape5dTest<armnn::DataType::QuantisedAsymm8>(
195 armnn::IWorkloadFactory& workloadFactory,
196 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
197
198template LayerTestResult<armnn::ResolveType<armnn::DataType::QuantisedSymm16>, 5>
199Reshape5dTest<armnn::DataType::QuantisedSymm16>(
200 armnn::IWorkloadFactory& workloadFactory,
201 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);