blob: 911696b9daada34f28f32cdfa32be566480b6ad8 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
telsoa014fcda012018-03-09 14:13:49 +00005
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01006#include "SplitterTestImpl.hpp"
7
Aron Virginas-Tar48623a02019-10-22 10:00:28 +01008#include <QuantizeHelper.hpp>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01009#include <ResolveType.hpp>
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +000010
telsoa014fcda012018-03-09 14:13:49 +000011
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010012#include <backendsCommon/test/TensorCopyUtils.hpp>
13#include <backendsCommon/test/WorkloadTestUtils.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014
David Beckac42efd2018-09-26 17:41:13 +010015#include <test/TensorHelpers.hpp>
telsoa014fcda012018-03-09 14:13:49 +000016
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010017namespace
18{
19
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000020template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +000021std::vector<LayerTestResult<T,3>> SplitterTestCommon(
22 armnn::IWorkloadFactory& workloadFactory,
23 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
24 float qScale = 0.0f,
25 int32_t qOffset = 0)
telsoa014fcda012018-03-09 14:13:49 +000026{
Jan Eilers8eb25602020-03-09 12:13:48 +000027 IgnoreUnused(memoryManager);
telsoa014fcda012018-03-09 14:13:49 +000028 unsigned int inputWidth = 5;
29 unsigned int inputHeight = 6;
30 unsigned int inputChannels = 3;
31
surmeh013537c2c2018-05-18 16:31:43 +010032 // NOTE: Compute Library imposes a restriction that the x and y dimension (input height and width)
33 // cannot be split.
telsoa01c577f2c2018-08-31 09:22:23 +010034 // For the reasons for this, see first comment on https://jira.arm.com/browse/IVGCVSW-1239
surmeh013537c2c2018-05-18 16:31:43 +010035 //
telsoa01c577f2c2018-08-31 09:22:23 +010036 // This test has therefore been recast to split the channels, then split the resulting subtensor.
telsoa014fcda012018-03-09 14:13:49 +000037
telsoa01c577f2c2018-08-31 09:22:23 +010038 // To take channel 0 of original output
39 // and channel 0 and channel 1 of the split subtensor.
surmeh013537c2c2018-05-18 16:31:43 +010040 unsigned int outputWidth1 = inputWidth;
41 unsigned int outputHeight1 = inputHeight;
42 unsigned int outputChannels1 = 1;
telsoa014fcda012018-03-09 14:13:49 +000043
telsoa01c577f2c2018-08-31 09:22:23 +010044 // To take channel 1 and 2 of the original output.
surmeh013537c2c2018-05-18 16:31:43 +010045 unsigned int outputWidth2 = inputWidth;
46 unsigned int outputHeight2 = inputHeight;
47 unsigned int outputChannels2 = 2;
telsoa014fcda012018-03-09 14:13:49 +000048
49
telsoa01c577f2c2018-08-31 09:22:23 +010050 // Define the tensor descriptors.
Ruomei Yan25339c32019-05-28 16:48:20 +010051 armnn::TensorInfo inputTensorInfo({ inputChannels, inputHeight, inputWidth }, ArmnnType, qScale, qOffset);
surmeh013537c2c2018-05-18 16:31:43 +010052
telsoa01c577f2c2018-08-31 09:22:23 +010053 // Outputs of the original split.
Ruomei Yan25339c32019-05-28 16:48:20 +010054 armnn::TensorInfo outputTensorInfo1({ outputChannels1, outputHeight1, outputWidth1 }, ArmnnType, qScale, qOffset);
55 armnn::TensorInfo outputTensorInfo2({ outputChannels2, outputHeight2, outputWidth2 }, ArmnnType, qScale, qOffset);
surmeh013537c2c2018-05-18 16:31:43 +010056
telsoa01c577f2c2018-08-31 09:22:23 +010057 // Outputs of the subsequent subtensor split.
Ruomei Yan25339c32019-05-28 16:48:20 +010058 armnn::TensorInfo outputTensorInfo3({ outputChannels1, outputHeight1, outputWidth1 }, ArmnnType, qScale, qOffset);
59 armnn::TensorInfo outputTensorInfo4({ outputChannels1, outputHeight1, outputWidth1 }, ArmnnType, qScale, qOffset);
telsoa014fcda012018-03-09 14:13:49 +000060
61 // Set quantization parameters if the requested type is a quantized type.
telsoa01c577f2c2018-08-31 09:22:23 +010062 // The quantization doesn't really matter as the splitter operator doesn't dequantize/quantize.
telsoa014fcda012018-03-09 14:13:49 +000063 if(armnn::IsQuantizedType<T>())
64 {
65 inputTensorInfo.SetQuantizationScale(qScale);
66 inputTensorInfo.SetQuantizationOffset(qOffset);
67 outputTensorInfo1.SetQuantizationScale(qScale);
68 outputTensorInfo1.SetQuantizationOffset(qOffset);
69 outputTensorInfo2.SetQuantizationScale(qScale);
70 outputTensorInfo2.SetQuantizationOffset(qOffset);
71 outputTensorInfo3.SetQuantizationScale(qScale);
72 outputTensorInfo3.SetQuantizationOffset(qOffset);
73 outputTensorInfo4.SetQuantizationScale(qScale);
74 outputTensorInfo4.SetQuantizationOffset(qOffset);
telsoa014fcda012018-03-09 14:13:49 +000075 }
76
77 LayerTestResult<T,3> ret1(outputTensorInfo1);
78 LayerTestResult<T,3> ret2(outputTensorInfo2);
79 LayerTestResult<T,3> ret3(outputTensorInfo3);
80 LayerTestResult<T,3> ret4(outputTensorInfo4);
telsoa014fcda012018-03-09 14:13:49 +000081
82 auto input = MakeTensor<T, 3>(inputTensorInfo, std::vector<T>(
Aron Virginas-Tar48623a02019-10-22 10:00:28 +010083 armnnUtils::QuantizedVector<T>({
telsoa014fcda012018-03-09 14:13:49 +000084 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
85 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
86 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
87 16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
88 21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
89 26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
90
91 31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
92 36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
93 41.0f, 42.0f, 43.0f, 44.0f, 45.0f,
94 46.0f, 47.0f, 48.0f, 49.0f, 50.0f,
95 51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
96 56.0f, 57.0f, 58.0f, 59.0f, 60.0f,
97
98 61.0f, 62.0f, 63.0f, 64.0f, 65.0f,
99 66.0f, 67.0f, 68.0f, 69.0f, 70.0f,
100 71.0f, 72.0f, 73.0f, 74.0f, 75.0f,
101 76.0f, 77.0f, 78.0f, 79.0f, 80.0f,
102 81.0f, 82.0f, 83.0f, 84.0f, 85.0f,
103 86.0f, 87.0f, 88.0f, 89.0f, 90.0f,
Aron Virginas-Tar48623a02019-10-22 10:00:28 +0100104 },
105 qScale, qOffset)
telsoa014fcda012018-03-09 14:13:49 +0000106 ));
107
telsoa01c577f2c2018-08-31 09:22:23 +0100108 // Channel 0 of the original input.
telsoa014fcda012018-03-09 14:13:49 +0000109 ret1.outputExpected = MakeTensor<T, 3>(outputTensorInfo1, std::vector<T>(
Aron Virginas-Tar48623a02019-10-22 10:00:28 +0100110 armnnUtils::QuantizedVector<T>({
surmeh013537c2c2018-05-18 16:31:43 +0100111 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
112 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
113 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
114 16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
115 21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
116 26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
Aron Virginas-Tar48623a02019-10-22 10:00:28 +0100117 },
118 qScale, qOffset)
telsoa014fcda012018-03-09 14:13:49 +0000119 ));
120
telsoa01c577f2c2018-08-31 09:22:23 +0100121 // Channel 1 & 2 of the original input.
telsoa014fcda012018-03-09 14:13:49 +0000122 ret2.outputExpected = MakeTensor<T, 3>(outputTensorInfo2, std::vector<T>(
Aron Virginas-Tar48623a02019-10-22 10:00:28 +0100123 armnnUtils::QuantizedVector<T>({
surmeh013537c2c2018-05-18 16:31:43 +0100124 31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
125 36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
126 41.0f, 42.0f, 43.0f, 44.0f, 45.0f,
127 46.0f, 47.0f, 48.0f, 49.0f, 50.0f,
128 51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
129 56.0f, 57.0f, 58.0f, 59.0f, 60.0f,
telsoa014fcda012018-03-09 14:13:49 +0000130
surmeh013537c2c2018-05-18 16:31:43 +0100131 61.0f, 62.0f, 63.0f, 64.0f, 65.0f,
132 66.0f, 67.0f, 68.0f, 69.0f, 70.0f,
133 71.0f, 72.0f, 73.0f, 74.0f, 75.0f,
134 76.0f, 77.0f, 78.0f, 79.0f, 80.0f,
135 81.0f, 82.0f, 83.0f, 84.0f, 85.0f,
136 86.0f, 87.0f, 88.0f, 89.0f, 90.0f,
Aron Virginas-Tar48623a02019-10-22 10:00:28 +0100137 },
138 qScale, qOffset)
telsoa014fcda012018-03-09 14:13:49 +0000139 ));
140
telsoa01c577f2c2018-08-31 09:22:23 +0100141 // Channel 0 of return 2 (i.e. channels 1 and 2 of the original input).
telsoa014fcda012018-03-09 14:13:49 +0000142 ret3.outputExpected = MakeTensor<T, 3>(outputTensorInfo3, std::vector<T>(
Aron Virginas-Tar48623a02019-10-22 10:00:28 +0100143 armnnUtils::QuantizedVector<T>({
surmeh013537c2c2018-05-18 16:31:43 +0100144 31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
145 36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
146 41.0f, 42.0f, 43.0f, 44.0f, 45.0f,
147 46.0f, 47.0f, 48.0f, 49.0f, 50.0f,
148 51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
149 56.0f, 57.0f, 58.0f, 59.0f, 60.0f,
Aron Virginas-Tar48623a02019-10-22 10:00:28 +0100150 },
151 qScale, qOffset)
telsoa014fcda012018-03-09 14:13:49 +0000152 ));
153
telsoa01c577f2c2018-08-31 09:22:23 +0100154 // Channel 1 of return 2.
telsoa014fcda012018-03-09 14:13:49 +0000155 ret4.outputExpected = MakeTensor<T, 3>(outputTensorInfo4, std::vector<T>(
Aron Virginas-Tar48623a02019-10-22 10:00:28 +0100156 armnnUtils::QuantizedVector<T>({
surmeh013537c2c2018-05-18 16:31:43 +0100157 61.0f, 62.0f, 63.0f, 64.0f, 65.0f,
158 66.0f, 67.0f, 68.0f, 69.0f, 70.0f,
159 71.0f, 72.0f, 73.0f, 74.0f, 75.0f,
160 76.0f, 77.0f, 78.0f, 79.0f, 80.0f,
161 81.0f, 82.0f, 83.0f, 84.0f, 85.0f,
162 86.0f, 87.0f, 88.0f, 89.0f, 90.0f,
Aron Virginas-Tar48623a02019-10-22 10:00:28 +0100163 },
164 qScale, qOffset)
telsoa014fcda012018-03-09 14:13:49 +0000165 ));
166
telsoa01c577f2c2018-08-31 09:22:23 +0100167 // NOTE: as a corollary of the splitting of x and y restriction the x and y values of the view origins
surmeh013537c2c2018-05-18 16:31:43 +0100168 // have to be zero, the co-ordinates are as per the tensor info above channels, height/y, width/x
telsoa01c577f2c2018-08-31 09:22:23 +0100169 // note that under the hood the compute engine reverses these i.e. its coordinate system is x, y, channels.
170 std::vector<unsigned int> wOrigin1 = {0, 0, 0}; //Extent of the window is defined by size of output[0].
telsoa014fcda012018-03-09 14:13:49 +0000171 armnn::SplitterQueueDescriptor::ViewOrigin window1(wOrigin1);
172
telsoa01c577f2c2018-08-31 09:22:23 +0100173 std::vector<unsigned int> wOrigin2 = {1, 0, 0}; //Extent of the window is defined by size of output[1].
telsoa014fcda012018-03-09 14:13:49 +0000174 armnn::SplitterQueueDescriptor::ViewOrigin window2(wOrigin2);
175
telsoa01c577f2c2018-08-31 09:22:23 +0100176 std::vector<unsigned int> wOrigin3 = {0, 0, 0}; //Extent of the window is defined by size of output[2].
telsoa014fcda012018-03-09 14:13:49 +0000177 armnn::SplitterQueueDescriptor::ViewOrigin window3(wOrigin3);
178
telsoa01c577f2c2018-08-31 09:22:23 +0100179 std::vector<unsigned int> wOrigin4 = {1, 0, 0}; //Extent of the window is defined by size of output[3].
telsoa014fcda012018-03-09 14:13:49 +0000180 armnn::SplitterQueueDescriptor::ViewOrigin window4(wOrigin4);
181
182 bool subTensorsSupported = workloadFactory.SupportsSubTensors();
183
Teresa Charlinfbf0e5b2020-08-17 01:01:06 +0100184 ARMNN_NO_DEPRECATE_WARN_BEGIN
telsoa014fcda012018-03-09 14:13:49 +0000185 std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
186
187 std::unique_ptr<armnn::ITensorHandle> outputHandle1 =
188 subTensorsSupported ?
189 workloadFactory.CreateSubTensorHandle(*inputHandle, outputTensorInfo1.GetShape(), wOrigin1.data()) :
190 workloadFactory.CreateTensorHandle(outputTensorInfo1);
191
192 std::unique_ptr<armnn::ITensorHandle> outputHandle2 =
193 subTensorsSupported ?
194 workloadFactory.CreateSubTensorHandle(*inputHandle, outputTensorInfo2.GetShape(), wOrigin2.data()) :
195 workloadFactory.CreateTensorHandle(outputTensorInfo2);
196
197 std::unique_ptr<armnn::ITensorHandle> outputHandle3 =
198 subTensorsSupported ?
surmeh013537c2c2018-05-18 16:31:43 +0100199 workloadFactory.CreateSubTensorHandle(*outputHandle2, outputTensorInfo3.GetShape(), wOrigin3.data()) :
telsoa014fcda012018-03-09 14:13:49 +0000200 workloadFactory.CreateTensorHandle(outputTensorInfo3);
201
202 std::unique_ptr<armnn::ITensorHandle> outputHandle4 =
203 subTensorsSupported ?
surmeh013537c2c2018-05-18 16:31:43 +0100204 workloadFactory.CreateSubTensorHandle(*outputHandle2, outputTensorInfo4.GetShape(), wOrigin4.data()) :
telsoa014fcda012018-03-09 14:13:49 +0000205 workloadFactory.CreateTensorHandle(outputTensorInfo4);
Teresa Charlinfbf0e5b2020-08-17 01:01:06 +0100206 ARMNN_NO_DEPRECATE_WARN_END
telsoa014fcda012018-03-09 14:13:49 +0000207
surmeh013537c2c2018-05-18 16:31:43 +0100208 // Do the first split
telsoa014fcda012018-03-09 14:13:49 +0000209 armnn::SplitterQueueDescriptor data;
210 armnn::WorkloadInfo info;
211 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
212 AddOutputToWorkload(data, info, outputTensorInfo1, outputHandle1.get());
213 AddOutputToWorkload(data, info, outputTensorInfo2, outputHandle2.get());
telsoa014fcda012018-03-09 14:13:49 +0000214
215 data.m_ViewOrigins.push_back(window1);
216 data.m_ViewOrigins.push_back(window2);
telsoa014fcda012018-03-09 14:13:49 +0000217
218 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateSplitter(data, info);
219
220 inputHandle->Allocate();
221 outputHandle1->Allocate();
222 outputHandle2->Allocate();
telsoa014fcda012018-03-09 14:13:49 +0000223
224 CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0]);
225
226 workload->Execute();
227
228 CopyDataFromITensorHandle(&ret1.output[0][0][0], outputHandle1.get());
229 CopyDataFromITensorHandle(&ret2.output[0][0][0], outputHandle2.get());
surmeh013537c2c2018-05-18 16:31:43 +0100230
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +0000231 // Do the second split.
surmeh013537c2c2018-05-18 16:31:43 +0100232 armnn::SplitterQueueDescriptor data2;
233 armnn::WorkloadInfo info2;
234 AddInputToWorkload(data2, info2, outputTensorInfo2, outputHandle2.get());
235 AddOutputToWorkload(data2, info2, outputTensorInfo3, outputHandle3.get());
236 AddOutputToWorkload(data2, info2, outputTensorInfo4, outputHandle4.get());
237
238 data2.m_ViewOrigins.push_back(window3);
239 data2.m_ViewOrigins.push_back(window4);
240
241 std::unique_ptr<armnn::IWorkload> workload2 = workloadFactory.CreateSplitter(data2, info2);
242
243 outputHandle3->Allocate();
244 outputHandle4->Allocate();
245
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +0000246 ExecuteWorkload(*workload2, memoryManager);
surmeh013537c2c2018-05-18 16:31:43 +0100247
telsoa014fcda012018-03-09 14:13:49 +0000248 CopyDataFromITensorHandle(&ret3.output[0][0][0], outputHandle3.get());
249 CopyDataFromITensorHandle(&ret4.output[0][0][0], outputHandle4.get());
telsoa014fcda012018-03-09 14:13:49 +0000250
surmeh013537c2c2018-05-18 16:31:43 +0100251 std::vector<LayerTestResult<T,3>> ret = {ret1, ret2, ret3, ret4,};
telsoa014fcda012018-03-09 14:13:49 +0000252
253 return ret;
254}
255
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +0000256template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +0000257LayerTestResult<T, 3> CopyViaSplitterTestImpl(
258 armnn::IWorkloadFactory& workloadFactory,
259 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
260 float qScale, int32_t qOffset)
telsoa014fcda012018-03-09 14:13:49 +0000261{
Jan Eilers8eb25602020-03-09 12:13:48 +0000262 IgnoreUnused(memoryManager);
Ruomei Yan25339c32019-05-28 16:48:20 +0100263 const armnn::TensorInfo tensorInfo({ 3, 6, 5 }, ArmnnType, qScale, qOffset);
Aron Virginas-Tar48623a02019-10-22 10:00:28 +0100264 auto input = MakeTensor<T, 3>(
265 tensorInfo,
266 armnnUtils::QuantizedVector<T>({
267 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
268 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
269 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
270 16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
271 21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
272 26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
telsoa014fcda012018-03-09 14:13:49 +0000273
Aron Virginas-Tar48623a02019-10-22 10:00:28 +0100274 31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
275 36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
276 41.0f, 42.0f, 43.0f, 44.0f, 45.0f,
277 46.0f, 47.0f, 48.0f, 49.0f, 50.0f,
278 51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
279 56.0f, 57.0f, 58.0f, 59.0f, 60.0f,
telsoa014fcda012018-03-09 14:13:49 +0000280
Aron Virginas-Tar48623a02019-10-22 10:00:28 +0100281 61.0f, 62.0f, 63.0f, 64.0f, 65.0f,
282 66.0f, 67.0f, 68.0f, 69.0f, 70.0f,
283 71.0f, 72.0f, 73.0f, 74.0f, 75.0f,
284 76.0f, 77.0f, 78.0f, 79.0f, 80.0f,
285 81.0f, 82.0f, 83.0f, 84.0f, 85.0f,
286 86.0f, 87.0f, 88.0f, 89.0f, 90.0f,
287 },
288 qScale, qOffset));
telsoa014fcda012018-03-09 14:13:49 +0000289
290 std::vector<unsigned int> origin = { 0, 0, 0 };
291 armnn::SplitterQueueDescriptor::ViewOrigin window(origin);
292
293 const bool subTensorsSupported = workloadFactory.SupportsSubTensors();
294
Teresa Charlinfbf0e5b2020-08-17 01:01:06 +0100295 ARMNN_NO_DEPRECATE_WARN_BEGIN
telsoa014fcda012018-03-09 14:13:49 +0000296 std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(tensorInfo);
297
298 std::unique_ptr<armnn::ITensorHandle> outputHandle =
299 subTensorsSupported ?
300 workloadFactory.CreateSubTensorHandle(*inputHandle, tensorInfo.GetShape(), origin.data()) :
301 workloadFactory.CreateTensorHandle(tensorInfo);
Teresa Charlinfbf0e5b2020-08-17 01:01:06 +0100302 ARMNN_NO_DEPRECATE_WARN_END
telsoa014fcda012018-03-09 14:13:49 +0000303
304 armnn::SplitterQueueDescriptor data;
305 armnn::WorkloadInfo info;
306 AddInputToWorkload(data, info, tensorInfo, inputHandle.get());
307 AddOutputToWorkload(data, info, tensorInfo, outputHandle.get());
308
309 data.m_ViewOrigins.push_back(window);
310
311 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateSplitter(data, info);
312
313 inputHandle->Allocate();
314 outputHandle->Allocate();
315
316 CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0]);
317
318 workload->Execute();
319
320 LayerTestResult<T, 3> ret(tensorInfo);
321 CopyDataFromITensorHandle(&ret.output[0][0][0], outputHandle.get());
322 ret.outputExpected = input;
323
324 return ret;
325}
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100326
327} // anonymous namespace
328
Matthew Jackson9bff1442019-09-12 09:08:23 +0100329std::vector<LayerTestResult<float,3>> SplitterFloat32Test(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100330 armnn::IWorkloadFactory& workloadFactory,
331 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
332{
333 return SplitterTestCommon<armnn::DataType::Float32>(workloadFactory, memoryManager);
334}
335
Matthew Jackson9bff1442019-09-12 09:08:23 +0100336std::vector<LayerTestResult<armnn::Half,3>> SplitterFloat16Test(
337 armnn::IWorkloadFactory& workloadFactory,
338 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
339{
340 return SplitterTestCommon<armnn::DataType::Float16>(workloadFactory, memoryManager);
341}
342
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100343std::vector<LayerTestResult<uint8_t,3>> SplitterUint8Test(
344 armnn::IWorkloadFactory& workloadFactory,
345 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
346{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000347 return SplitterTestCommon<armnn::DataType::QAsymmU8>(workloadFactory, memoryManager, 1.0f, 0);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100348}
349
350std::vector<LayerTestResult<int16_t,3>> SplitterInt16Test(
351 armnn::IWorkloadFactory& workloadFactory,
352 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
353{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000354 return SplitterTestCommon<armnn::DataType::QSymmS16>(workloadFactory, memoryManager, 1.0f, 0);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100355}
356
Matthew Jackson9bff1442019-09-12 09:08:23 +0100357LayerTestResult<float, 3> CopyViaSplitterFloat32Test(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100358 armnn::IWorkloadFactory& workloadFactory,
359 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
360{
361 return CopyViaSplitterTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, 0.0f, 0);
362}
363
Matthew Jackson9bff1442019-09-12 09:08:23 +0100364LayerTestResult<armnn::Half, 3> CopyViaSplitterFloat16Test(
365 armnn::IWorkloadFactory& workloadFactory,
366 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
367{
368 return CopyViaSplitterTestImpl<armnn::DataType::Float16>(workloadFactory, memoryManager, 0.0f, 0);
369}
370
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100371LayerTestResult<uint8_t, 3> CopyViaSplitterUint8Test(
372 armnn::IWorkloadFactory& workloadFactory,
373 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
374{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000375 return CopyViaSplitterTestImpl<armnn::DataType::QAsymmU8>(workloadFactory, memoryManager, 1.0f, 0);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100376}
377
378LayerTestResult<int16_t, 3> CopyViaSplitterInt16Test(
379 armnn::IWorkloadFactory& workloadFactory,
380 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
381{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000382 return CopyViaSplitterTestImpl<armnn::DataType::QSymmS16>(workloadFactory, memoryManager, 1.0f, 0);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100383}