blob: c5039a12021c086e4fb71030b205dd73c38350bb [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01005
telsoa014fcda012018-03-09 14:13:49 +00006#pragma once
7
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01008#include <ResolveType.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009
telsoa014fcda012018-03-09 14:13:49 +000010#include <armnn/ArmNN.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000012#include <armnn/backends/IBackendInternal.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013#include <backendsCommon/WorkloadFactory.hpp>
telsoa014fcda012018-03-09 14:13:49 +000014
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010015#include <backendsCommon/test/WorkloadTestUtils.hpp>
16
17#include <test/TensorHelpers.hpp>
18
telsoa014fcda012018-03-09 14:13:49 +000019template<typename T>
20LayerTestResult<T, 4> SimplePermuteTestImpl(
21 armnn::IWorkloadFactory& workloadFactory,
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +000022 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
telsoa014fcda012018-03-09 14:13:49 +000023 armnn::PermuteDescriptor descriptor,
24 armnn::TensorInfo inputTensorInfo,
25 armnn::TensorInfo outputTensorInfo,
26 const std::vector<T>& inputData,
27 const std::vector<T>& outputExpectedData)
28{
29 auto input = MakeTensor<T, 4>(inputTensorInfo, inputData);
30
31 LayerTestResult<T, 4> ret(outputTensorInfo);
32 ret.outputExpected = MakeTensor<T, 4>(outputTensorInfo, outputExpectedData);
33
34 std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
35 std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
36
37 armnn::PermuteQueueDescriptor data;
38 data.m_Parameters = descriptor;
39 armnn::WorkloadInfo info;
40 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
41 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
42
43 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreatePermute(data, info);
44
45 inputHandle->Allocate();
46 outputHandle->Allocate();
47
48 CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
49
50 workload->Execute();
51
52 CopyDataFromITensorHandle(&ret.output[0][0][0][0], outputHandle.get());
53
54 return ret;
55}
56
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +010057template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
58LayerTestResult<T, 4> SimplePermuteTest(
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +000059 armnn::IWorkloadFactory& workloadFactory,
60 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
telsoa014fcda012018-03-09 14:13:49 +000061{
62 armnn::TensorInfo inputTensorInfo;
63 armnn::TensorInfo outputTensorInfo;
64
65 unsigned int inputShape[] = { 1, 2, 2, 2 };
66 unsigned int outputShape[] = { 1, 2, 2, 2 };
67
68 armnn::PermuteDescriptor descriptor;
69 descriptor.m_DimMappings = {0U, 3U, 1U, 2U};
70
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +010071 inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType);
72 outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
telsoa014fcda012018-03-09 14:13:49 +000073
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +010074 // Set quantization parameters if the requested type is a quantized type.
75 if(armnn::IsQuantizedType<T>())
76 {
77 inputTensorInfo.SetQuantizationScale(0.5f);
78 inputTensorInfo.SetQuantizationOffset(5);
79 outputTensorInfo.SetQuantizationScale(0.5f);
80 outputTensorInfo.SetQuantizationOffset(5);
81 }
telsoa014fcda012018-03-09 14:13:49 +000082
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +010083 std::vector<T> input = std::vector<T>(
84 {
85 1, 2,
86 3, 4,
87 5, 6,
88 7, 8
89 });
telsoa014fcda012018-03-09 14:13:49 +000090
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +010091 std::vector<T> outputExpected = std::vector<T>(
92 {
93 1, 5, 2, 6,
94 3, 7, 4, 8
95 });
telsoa014fcda012018-03-09 14:13:49 +000096
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +010097 return SimplePermuteTestImpl<T>(workloadFactory, memoryManager,
98 descriptor, inputTensorInfo,
99 outputTensorInfo, input, outputExpected);
telsoa014fcda012018-03-09 14:13:49 +0000100}
101
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100102template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
103LayerTestResult<T, 4> PermuteValueSet1Test(
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +0000104 armnn::IWorkloadFactory& workloadFactory,
105 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
surmeh01bceff2f2018-03-29 16:29:27 +0100106{
107 armnn::TensorInfo inputTensorInfo;
108 armnn::TensorInfo outputTensorInfo;
109
110 unsigned int inputShape[] = { 1, 2, 2, 3 };
111 unsigned int outputShape[] = { 1, 3, 2, 2 };
112
113 armnn::PermuteDescriptor descriptor;
114 descriptor.m_DimMappings = {0U, 2U, 3U, 1U};
115
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100116 inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType);
117 outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
surmeh01bceff2f2018-03-29 16:29:27 +0100118
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100119 // Set quantization parameters if the requested type is a quantized type.
120 if(armnn::IsQuantizedType<T>())
121 {
122 inputTensorInfo.SetQuantizationScale(0.5f);
123 inputTensorInfo.SetQuantizationOffset(5);
124 outputTensorInfo.SetQuantizationScale(0.5f);
125 outputTensorInfo.SetQuantizationOffset(5);
126 }
surmeh01bceff2f2018-03-29 16:29:27 +0100127
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100128 std::vector<T> input = std::vector<T>(
129 {
130 1, 2, 3,
131 11, 12, 13,
132 21, 22, 23,
133 31, 32, 33
134 });
surmeh01bceff2f2018-03-29 16:29:27 +0100135
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100136 std::vector<T> outputExpected = std::vector<T>(
137 {
138 1, 11, 21, 31,
139 2, 12, 22, 32,
140 3, 13, 23, 33
141 });
142
143 return SimplePermuteTestImpl<T>(workloadFactory, memoryManager,
144 descriptor, inputTensorInfo,
145 outputTensorInfo, input, outputExpected);
surmeh01bceff2f2018-03-29 16:29:27 +0100146}
147
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100148template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
149LayerTestResult<T, 4> PermuteValueSet2Test(
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +0000150 armnn::IWorkloadFactory& workloadFactory,
151 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
surmeh01bceff2f2018-03-29 16:29:27 +0100152{
153 armnn::TensorInfo inputTensorInfo;
154 armnn::TensorInfo outputTensorInfo;
155
156 unsigned int inputShape[] = { 1, 3, 2, 2 };
157 unsigned int outputShape[] = { 1, 2, 2, 3 };
158
159 armnn::PermuteDescriptor descriptor;
160 descriptor.m_DimMappings = {0U, 3U, 1U, 2U};
161
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100162 inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType);
163 outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
surmeh01bceff2f2018-03-29 16:29:27 +0100164
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100165 // Set quantization parameters if the requested type is a quantized type.
166 if(armnn::IsQuantizedType<T>())
167 {
168 inputTensorInfo.SetQuantizationScale(0.5f);
169 inputTensorInfo.SetQuantizationOffset(5);
170 outputTensorInfo.SetQuantizationScale(0.5f);
171 outputTensorInfo.SetQuantizationOffset(5);
172 }
surmeh01bceff2f2018-03-29 16:29:27 +0100173
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100174 std::vector<T> input = std::vector<T>(
175 {
176 1, 11, 21, 31,
177 2, 12, 22, 32,
178 3, 13, 23, 33
179 });
surmeh01bceff2f2018-03-29 16:29:27 +0100180
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100181 std::vector<T> outputExpected = std::vector<T>(
182 {
183 1, 2, 3,
184 11, 12, 13,
185 21, 22, 23,
186 31, 32, 33,
187 });
188
189 return SimplePermuteTestImpl<T>(workloadFactory, memoryManager,
190 descriptor, inputTensorInfo,
191 outputTensorInfo, input, outputExpected);
surmeh01bceff2f2018-03-29 16:29:27 +0100192}
193
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100194template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
195LayerTestResult<T, 4> PermuteValueSet3Test(
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +0000196 armnn::IWorkloadFactory& workloadFactory,
197 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
surmeh01bceff2f2018-03-29 16:29:27 +0100198{
199 armnn::TensorInfo inputTensorInfo;
200 armnn::TensorInfo outputTensorInfo;
201
202 unsigned int inputShape[] = { 1, 2, 3, 3 };
203 unsigned int outputShape[] = { 1, 3, 2, 3 };
204
205 armnn::PermuteDescriptor descriptor;
206 descriptor.m_DimMappings = {0U, 2U, 3U, 1U};
207
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100208 inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType);
209 outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
surmeh01bceff2f2018-03-29 16:29:27 +0100210
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100211 // Set quantization parameters if the requested type is a quantized type.
212 if(armnn::IsQuantizedType<T>())
213 {
214 inputTensorInfo.SetQuantizationScale(0.5f);
215 inputTensorInfo.SetQuantizationOffset(5);
216 outputTensorInfo.SetQuantizationScale(0.5f);
217 outputTensorInfo.SetQuantizationOffset(5);
218 }
219
220 std::vector<T> input = std::vector<T>(
surmeh01bceff2f2018-03-29 16:29:27 +0100221 {
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100222 1, 2, 3,
223 11, 12, 13,
224 21, 22, 23,
225 31, 32, 33,
226 41, 42, 43,
227 51, 52, 53
surmeh01bceff2f2018-03-29 16:29:27 +0100228 });
229
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100230 std::vector<T> outputExpected = std::vector<T>(
surmeh01bceff2f2018-03-29 16:29:27 +0100231 {
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100232 1, 11, 21, 31, 41, 51,
233 2, 12, 22, 32, 42, 52,
234 3, 13, 23, 33, 43, 53
surmeh01bceff2f2018-03-29 16:29:27 +0100235 });
236
Narumol Prangnawarat86bb4e12019-07-08 11:36:05 +0100237 return SimplePermuteTestImpl<T>(workloadFactory, memoryManager,
238 descriptor, inputTensorInfo,
239 outputTensorInfo, input, outputExpected);
surmeh01bceff2f2018-03-29 16:29:27 +0100240}