blob: 83746f324fbc48440387d580b64eea8549990621 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01005
telsoa014fcda012018-03-09 14:13:49 +00006#pragma once
7
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01008#include <ResolveType.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009
telsoa014fcda012018-03-09 14:13:49 +000010#include <armnn/ArmNN.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000012#include <armnn/backends/IBackendInternal.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013#include <backendsCommon/WorkloadFactory.hpp>
telsoa014fcda012018-03-09 14:13:49 +000014
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010015#include <backendsCommon/test/WorkloadTestUtils.hpp>
16
17#include <test/TensorHelpers.hpp>
18
telsoa014fcda012018-03-09 14:13:49 +000019template<typename T>
20LayerTestResult<T, 4> SimplePermuteTestImpl(
21 armnn::IWorkloadFactory& workloadFactory,
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +000022 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
telsoa014fcda012018-03-09 14:13:49 +000023 armnn::PermuteDescriptor descriptor,
24 armnn::TensorInfo inputTensorInfo,
25 armnn::TensorInfo outputTensorInfo,
26 const std::vector<T>& inputData,
27 const std::vector<T>& outputExpectedData)
28{
Derek Lambertic374ff02019-12-10 21:57:35 +000029 boost::ignore_unused(memoryManager);
telsoa014fcda012018-03-09 14:13:49 +000030 auto input = MakeTensor<T, 4>(inputTensorInfo, inputData);
31
32 LayerTestResult<T, 4> ret(outputTensorInfo);
33 ret.outputExpected = MakeTensor<T, 4>(outputTensorInfo, outputExpectedData);
34
35 std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
36 std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
37
38 armnn::PermuteQueueDescriptor data;
39 data.m_Parameters = descriptor;
40 armnn::WorkloadInfo info;
41 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
42 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
43
44 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreatePermute(data, info);
45
46 inputHandle->Allocate();
47 outputHandle->Allocate();
48
49 CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
50
51 workload->Execute();
52
53 CopyDataFromITensorHandle(&ret.output[0][0][0][0], outputHandle.get());
54
55 return ret;
56}
57
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +010058template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
59LayerTestResult<T, 4> SimplePermuteTest(
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +000060 armnn::IWorkloadFactory& workloadFactory,
61 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
telsoa014fcda012018-03-09 14:13:49 +000062{
63 armnn::TensorInfo inputTensorInfo;
64 armnn::TensorInfo outputTensorInfo;
65
66 unsigned int inputShape[] = { 1, 2, 2, 2 };
67 unsigned int outputShape[] = { 1, 2, 2, 2 };
68
69 armnn::PermuteDescriptor descriptor;
70 descriptor.m_DimMappings = {0U, 3U, 1U, 2U};
71
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +010072 inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType);
73 outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
telsoa014fcda012018-03-09 14:13:49 +000074
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +010075 // Set quantization parameters if the requested type is a quantized type.
76 if(armnn::IsQuantizedType<T>())
77 {
78 inputTensorInfo.SetQuantizationScale(0.5f);
79 inputTensorInfo.SetQuantizationOffset(5);
80 outputTensorInfo.SetQuantizationScale(0.5f);
81 outputTensorInfo.SetQuantizationOffset(5);
82 }
telsoa014fcda012018-03-09 14:13:49 +000083
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +010084 std::vector<T> input = std::vector<T>(
85 {
86 1, 2,
87 3, 4,
88 5, 6,
89 7, 8
90 });
telsoa014fcda012018-03-09 14:13:49 +000091
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +010092 std::vector<T> outputExpected = std::vector<T>(
93 {
94 1, 5, 2, 6,
95 3, 7, 4, 8
96 });
telsoa014fcda012018-03-09 14:13:49 +000097
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +010098 return SimplePermuteTestImpl<T>(workloadFactory, memoryManager,
99 descriptor, inputTensorInfo,
100 outputTensorInfo, input, outputExpected);
telsoa014fcda012018-03-09 14:13:49 +0000101}
102
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100103template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
104LayerTestResult<T, 4> PermuteValueSet1Test(
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +0000105 armnn::IWorkloadFactory& workloadFactory,
106 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
surmeh01bceff2f2018-03-29 16:29:27 +0100107{
108 armnn::TensorInfo inputTensorInfo;
109 armnn::TensorInfo outputTensorInfo;
110
111 unsigned int inputShape[] = { 1, 2, 2, 3 };
112 unsigned int outputShape[] = { 1, 3, 2, 2 };
113
114 armnn::PermuteDescriptor descriptor;
115 descriptor.m_DimMappings = {0U, 2U, 3U, 1U};
116
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100117 inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType);
118 outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
surmeh01bceff2f2018-03-29 16:29:27 +0100119
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100120 // Set quantization parameters if the requested type is a quantized type.
121 if(armnn::IsQuantizedType<T>())
122 {
123 inputTensorInfo.SetQuantizationScale(0.5f);
124 inputTensorInfo.SetQuantizationOffset(5);
125 outputTensorInfo.SetQuantizationScale(0.5f);
126 outputTensorInfo.SetQuantizationOffset(5);
127 }
surmeh01bceff2f2018-03-29 16:29:27 +0100128
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100129 std::vector<T> input = std::vector<T>(
130 {
131 1, 2, 3,
132 11, 12, 13,
133 21, 22, 23,
134 31, 32, 33
135 });
surmeh01bceff2f2018-03-29 16:29:27 +0100136
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100137 std::vector<T> outputExpected = std::vector<T>(
138 {
139 1, 11, 21, 31,
140 2, 12, 22, 32,
141 3, 13, 23, 33
142 });
143
144 return SimplePermuteTestImpl<T>(workloadFactory, memoryManager,
145 descriptor, inputTensorInfo,
146 outputTensorInfo, input, outputExpected);
surmeh01bceff2f2018-03-29 16:29:27 +0100147}
148
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100149template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
150LayerTestResult<T, 4> PermuteValueSet2Test(
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +0000151 armnn::IWorkloadFactory& workloadFactory,
152 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
surmeh01bceff2f2018-03-29 16:29:27 +0100153{
154 armnn::TensorInfo inputTensorInfo;
155 armnn::TensorInfo outputTensorInfo;
156
157 unsigned int inputShape[] = { 1, 3, 2, 2 };
158 unsigned int outputShape[] = { 1, 2, 2, 3 };
159
160 armnn::PermuteDescriptor descriptor;
161 descriptor.m_DimMappings = {0U, 3U, 1U, 2U};
162
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100163 inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType);
164 outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
surmeh01bceff2f2018-03-29 16:29:27 +0100165
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100166 // Set quantization parameters if the requested type is a quantized type.
167 if(armnn::IsQuantizedType<T>())
168 {
169 inputTensorInfo.SetQuantizationScale(0.5f);
170 inputTensorInfo.SetQuantizationOffset(5);
171 outputTensorInfo.SetQuantizationScale(0.5f);
172 outputTensorInfo.SetQuantizationOffset(5);
173 }
surmeh01bceff2f2018-03-29 16:29:27 +0100174
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100175 std::vector<T> input = std::vector<T>(
176 {
177 1, 11, 21, 31,
178 2, 12, 22, 32,
179 3, 13, 23, 33
180 });
surmeh01bceff2f2018-03-29 16:29:27 +0100181
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100182 std::vector<T> outputExpected = std::vector<T>(
183 {
184 1, 2, 3,
185 11, 12, 13,
186 21, 22, 23,
187 31, 32, 33,
188 });
189
190 return SimplePermuteTestImpl<T>(workloadFactory, memoryManager,
191 descriptor, inputTensorInfo,
192 outputTensorInfo, input, outputExpected);
surmeh01bceff2f2018-03-29 16:29:27 +0100193}
194
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100195template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
196LayerTestResult<T, 4> PermuteValueSet3Test(
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +0000197 armnn::IWorkloadFactory& workloadFactory,
198 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
surmeh01bceff2f2018-03-29 16:29:27 +0100199{
200 armnn::TensorInfo inputTensorInfo;
201 armnn::TensorInfo outputTensorInfo;
202
203 unsigned int inputShape[] = { 1, 2, 3, 3 };
204 unsigned int outputShape[] = { 1, 3, 2, 3 };
205
206 armnn::PermuteDescriptor descriptor;
207 descriptor.m_DimMappings = {0U, 2U, 3U, 1U};
208
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100209 inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType);
210 outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
surmeh01bceff2f2018-03-29 16:29:27 +0100211
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100212 // Set quantization parameters if the requested type is a quantized type.
213 if(armnn::IsQuantizedType<T>())
214 {
215 inputTensorInfo.SetQuantizationScale(0.5f);
216 inputTensorInfo.SetQuantizationOffset(5);
217 outputTensorInfo.SetQuantizationScale(0.5f);
218 outputTensorInfo.SetQuantizationOffset(5);
219 }
220
221 std::vector<T> input = std::vector<T>(
surmeh01bceff2f2018-03-29 16:29:27 +0100222 {
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100223 1, 2, 3,
224 11, 12, 13,
225 21, 22, 23,
226 31, 32, 33,
227 41, 42, 43,
228 51, 52, 53
surmeh01bceff2f2018-03-29 16:29:27 +0100229 });
230
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100231 std::vector<T> outputExpected = std::vector<T>(
surmeh01bceff2f2018-03-29 16:29:27 +0100232 {
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100233 1, 11, 21, 31, 41, 51,
234 2, 12, 22, 32, 42, 52,
235 3, 13, 23, 33, 43, 53
surmeh01bceff2f2018-03-29 16:29:27 +0100236 });
237
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100238 return SimplePermuteTestImpl<T>(workloadFactory, memoryManager,
239 descriptor, inputTensorInfo,
240 outputTensorInfo, input, outputExpected);
surmeh01bceff2f2018-03-29 16:29:27 +0100241}