blob: 13ac7fc233b8acf65e40efac3f2b181a77d50ac9 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Teresa Charlinacb3ec52023-04-03 19:57:00 +01002// Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
arovir0143095f32018-10-09 18:04:24 +01005
Sadik Armagana097d2a2021-11-24 15:47:28 +00006#include <CreateWorkload.hpp>
arovir0143095f32018-10-09 18:04:24 +01007
Jan Eilersbb446e52020-04-02 13:56:54 +01008#include <armnn/utility/PolymorphicDowncast.hpp>
Matthew Bentham4cefc412019-06-18 16:14:34 +01009#include <reference/RefTensorHandle.hpp>
Teresa Charlin788e2a62022-01-17 21:19:52 +000010#include <reference/RefTensorHandleFactory.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000011#include <reference/RefWorkloadFactory.hpp>
12#include <reference/workloads/RefWorkloads.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Sadik Armagan1625efc2021-06-10 18:24:34 +010014#include <doctest/doctest.h>
15
telsoa014fcda012018-03-09 14:13:49 +000016namespace
17{
18
19template<typename Workload>
20void CheckInputOutput(std::unique_ptr<Workload> workload, const TensorInfo& inputInfo, const TensorInfo& outputInfo)
21{
22 auto queueDescriptor = workload->GetData();
Jan Eilersbb446e52020-04-02 13:56:54 +010023 auto inputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
24 auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +010025 CHECK((inputHandle->GetTensorInfo() == inputInfo));
26 CHECK((outputHandle->GetTensorInfo() == outputInfo));
telsoa014fcda012018-03-09 14:13:49 +000027}
28
29template <typename Workload>
30void CheckInputsOutput(std::unique_ptr<Workload> workload,
31 const TensorInfo& inputInfo0,
32 const TensorInfo& inputInfo1,
33 const TensorInfo& outputInfo)
34{
35 auto queueDescriptor = workload->GetData();
Jan Eilersbb446e52020-04-02 13:56:54 +010036 auto inputHandle0 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
37 auto inputHandle1 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[1]);
38 auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +010039 CHECK((inputHandle0->GetTensorInfo() == inputInfo0));
40 CHECK((inputHandle1->GetTensorInfo() == inputInfo1));
41 CHECK((outputHandle->GetTensorInfo() == outputInfo));
telsoa014fcda012018-03-09 14:13:49 +000042}
Matthew Bentham7c1603a2019-06-21 17:22:23 +010043
44armnn::RefWorkloadFactory GetFactory()
45{
46 std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
47 return RefWorkloadFactory(memoryManager);
48}
49
telsoa014fcda012018-03-09 14:13:49 +000050}
51
Sadik Armagan1625efc2021-06-10 18:24:34 +010052TEST_SUITE("CreateWorkloadRef")
53{
telsoa01c577f2c2018-08-31 09:22:23 +010054template <typename ActivationWorkloadType, armnn::DataType DataType>
telsoa014fcda012018-03-09 14:13:49 +000055static void RefCreateActivationWorkloadTest()
56{
57 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +010058 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +010059 auto workload = CreateActivationWorkloadTest<ActivationWorkloadType, DataType>(factory, graph);
telsoa014fcda012018-03-09 14:13:49 +000060
telsoa01c577f2c2018-08-31 09:22:23 +010061 // Checks that outputs are as we expect them (see definition of CreateActivationWorkloadTest).
telsoa014fcda012018-03-09 14:13:49 +000062 CheckInputOutput(std::move(workload),
telsoa01c577f2c2018-08-31 09:22:23 +010063 TensorInfo({ 1, 1 }, DataType),
64 TensorInfo({ 1, 1 }, DataType));
telsoa014fcda012018-03-09 14:13:49 +000065}
66
Sadik Armagan1625efc2021-06-10 18:24:34 +010067TEST_CASE("CreateActivationFloat32Workload")
telsoa014fcda012018-03-09 14:13:49 +000068{
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +010069 RefCreateActivationWorkloadTest<RefActivationWorkload, armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +000070}
71
Sadik Armagan1625efc2021-06-10 18:24:34 +010072TEST_CASE("CreateActivationUint8Workload")
telsoa014fcda012018-03-09 14:13:49 +000073{
Derek Lambertif90c56d2020-01-10 17:14:08 +000074 RefCreateActivationWorkloadTest<RefActivationWorkload, armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +000075}
76
David Beckbc392452018-09-10 14:47:28 +010077template <typename WorkloadType,
78 typename DescriptorType,
79 typename LayerType,
80 armnn::DataType DataType>
Éanna Ó Catháind57415d2018-11-28 16:24:38 +000081static void RefCreateElementwiseWorkloadTest()
telsoa014fcda012018-03-09 14:13:49 +000082{
83 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +010084 RefWorkloadFactory factory = GetFactory();
Éanna Ó Catháind57415d2018-11-28 16:24:38 +000085 auto workload = CreateElementwiseWorkloadTest<WorkloadType, DescriptorType, LayerType, DataType>(
86 factory, graph);
telsoa014fcda012018-03-09 14:13:49 +000087
telsoa014fcda012018-03-09 14:13:49 +000088 CheckInputsOutput(std::move(workload),
telsoa01c577f2c2018-08-31 09:22:23 +010089 TensorInfo({ 2, 3 }, DataType),
90 TensorInfo({ 2, 3 }, DataType),
91 TensorInfo({ 2, 3 }, DataType));
telsoa014fcda012018-03-09 14:13:49 +000092}
93
Sadik Armagan1625efc2021-06-10 18:24:34 +010094TEST_CASE("CreateSubtractionWorkloadWithBlobTest")
Keith Davisdf04d232020-10-23 17:20:05 +010095{
96 Graph graph;
97 RefWorkloadFactory factory = GetFactory();
98 armnn::DataType DataType = armnn::DataType::Float32;
99
100 auto workload = CreateSubtractionWithBlobWorkloadTest<RefSubtractionWorkload<>,
101 SubtractionQueueDescriptor,
102 armnn::DataType::Float32>
103 (factory, graph);
104
105 CheckInputsOutput(std::move(workload),
106 TensorInfo({ 2, 3 }, DataType),
107 TensorInfo({ 2, 3 }, DataType),
108 TensorInfo({ 2, 3 }, DataType));
109}
110
Sadik Armagan1625efc2021-06-10 18:24:34 +0100111TEST_CASE("CreateAdditionWorkloadWithBlobTest")
Keith Davisdf04d232020-10-23 17:20:05 +0100112{
113 Graph graph;
114 RefWorkloadFactory factory = GetFactory();
115 armnn::DataType DataType = armnn::DataType::Float32;
116
117 auto workload = CreateAdditionWithBlobWorkloadTest<RefAdditionWorkload<>,
118 AdditionQueueDescriptor,
119 armnn::DataType::Float32>(factory, graph);
120
121 CheckInputsOutput(std::move(workload),
122 TensorInfo({ 2, 3 }, DataType),
123 TensorInfo({ 2, 3 }, DataType),
124 TensorInfo({ 2, 3 }, DataType));
125}
126
Sadik Armagan1625efc2021-06-10 18:24:34 +0100127TEST_CASE("CreateMultiplicationWorkloadWithBlobTest")
Keith Davisdf04d232020-10-23 17:20:05 +0100128{
129 Graph graph;
130 RefWorkloadFactory factory = GetFactory();
131 armnn::DataType DataType = armnn::DataType::Float32;
132
133 auto workload = CreateMultiplicationWithBlobWorkloadTest<RefMultiplicationWorkload<>,
134 MultiplicationQueueDescriptor,
135 armnn::DataType::Float32>(factory, graph);
136
137 CheckInputsOutput(std::move(workload),
138 TensorInfo({2, 3}, DataType),
139 TensorInfo({2, 3}, DataType),
140 TensorInfo({2, 3}, DataType));
141}
142
Sadik Armagan1625efc2021-06-10 18:24:34 +0100143TEST_CASE("CreateAdditionFloatWorkload")
telsoa014fcda012018-03-09 14:13:49 +0000144{
Finn Williamscbd2c232020-06-22 15:58:32 +0100145 RefCreateElementwiseWorkloadTest<RefAdditionWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000146 AdditionQueueDescriptor,
147 AdditionLayer,
148 armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000149}
150
Sadik Armagan1625efc2021-06-10 18:24:34 +0100151TEST_CASE("CreateAdditionUint8Workload")
telsoa014fcda012018-03-09 14:13:49 +0000152{
Finn Williamscbd2c232020-06-22 15:58:32 +0100153 RefCreateElementwiseWorkloadTest<RefAdditionWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000154 AdditionQueueDescriptor,
155 AdditionLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000156 armnn::DataType::QAsymmU8>();
David Beckbc392452018-09-10 14:47:28 +0100157}
158
Sadik Armagan1625efc2021-06-10 18:24:34 +0100159TEST_CASE("CreateAdditionInt16Workload")
Sadik Armagan2999a022019-04-09 14:20:12 +0100160{
Finn Williamscbd2c232020-06-22 15:58:32 +0100161 RefCreateElementwiseWorkloadTest<RefAdditionWorkload<>,
Sadik Armagan2999a022019-04-09 14:20:12 +0100162 AdditionQueueDescriptor,
163 AdditionLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000164 armnn::DataType::QSymmS16>();
Sadik Armagan2999a022019-04-09 14:20:12 +0100165}
166
Sadik Armagan1625efc2021-06-10 18:24:34 +0100167TEST_CASE("CreateAdditionInt32Workload")
Finn Williamscbd2c232020-06-22 15:58:32 +0100168{
169 RefCreateElementwiseWorkloadTest<RefAdditionWorkload<int32_t>,
170 AdditionQueueDescriptor,
171 AdditionLayer,
172 armnn::DataType::Signed32>();
173}
174
Sadik Armagan1625efc2021-06-10 18:24:34 +0100175TEST_CASE("CreateSubtractionFloat32Workload")
David Beckbc392452018-09-10 14:47:28 +0100176{
Finn Williamscbd2c232020-06-22 15:58:32 +0100177 RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000178 SubtractionQueueDescriptor,
179 SubtractionLayer,
180 armnn::DataType::Float32>();
David Beckbc392452018-09-10 14:47:28 +0100181}
182
Sadik Armagan1625efc2021-06-10 18:24:34 +0100183TEST_CASE("CreateSubtractionFloat16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100184{
Finn Williamscbd2c232020-06-22 15:58:32 +0100185 RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100186 SubtractionQueueDescriptor,
187 SubtractionLayer,
188 armnn::DataType::Float16>();
189}
190
Sadik Armagan1625efc2021-06-10 18:24:34 +0100191TEST_CASE("CreateSubtractionUint8Workload")
David Beckbc392452018-09-10 14:47:28 +0100192{
Finn Williamscbd2c232020-06-22 15:58:32 +0100193 RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000194 SubtractionQueueDescriptor,
195 SubtractionLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000196 armnn::DataType::QAsymmU8>();
David Beckbc392452018-09-10 14:47:28 +0100197}
198
Sadik Armagan1625efc2021-06-10 18:24:34 +0100199TEST_CASE("CreateSubtractionInt16Workload")
Sadik Armagan2999a022019-04-09 14:20:12 +0100200{
Finn Williamscbd2c232020-06-22 15:58:32 +0100201 RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
Sadik Armagan2999a022019-04-09 14:20:12 +0100202 SubtractionQueueDescriptor,
203 SubtractionLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000204 armnn::DataType::QSymmS16>();
Sadik Armagan2999a022019-04-09 14:20:12 +0100205}
206
Sadik Armagan1625efc2021-06-10 18:24:34 +0100207TEST_CASE("CreateSubtractionInt32Workload")
Finn Williamscbd2c232020-06-22 15:58:32 +0100208{
209 RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<int32_t>,
210 SubtractionQueueDescriptor,
211 SubtractionLayer,
212 armnn::DataType::Signed32>();
213}
214
Sadik Armagan1625efc2021-06-10 18:24:34 +0100215TEST_CASE("CreateMultiplicationFloatWorkload")
David Beckbc392452018-09-10 14:47:28 +0100216{
Finn Williamscbd2c232020-06-22 15:58:32 +0100217 RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000218 MultiplicationQueueDescriptor,
219 MultiplicationLayer,
220 armnn::DataType::Float32>();
David Beckbc392452018-09-10 14:47:28 +0100221}
222
Sadik Armagan1625efc2021-06-10 18:24:34 +0100223TEST_CASE("CreateMultiplicationUint8Workload")
David Beckbc392452018-09-10 14:47:28 +0100224{
Finn Williamscbd2c232020-06-22 15:58:32 +0100225 RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000226 MultiplicationQueueDescriptor,
227 MultiplicationLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000228 armnn::DataType::QAsymmU8>();
David Beckbc392452018-09-10 14:47:28 +0100229}
230
Sadik Armagan1625efc2021-06-10 18:24:34 +0100231TEST_CASE("CreateMultiplicationInt16Workload")
Sadik Armagan2999a022019-04-09 14:20:12 +0100232{
Finn Williamscbd2c232020-06-22 15:58:32 +0100233 RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<>,
Sadik Armagan2999a022019-04-09 14:20:12 +0100234 MultiplicationQueueDescriptor,
235 MultiplicationLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000236 armnn::DataType::QSymmS16>();
Sadik Armagan2999a022019-04-09 14:20:12 +0100237}
238
Sadik Armagan1625efc2021-06-10 18:24:34 +0100239TEST_CASE("CreateMultiplicationInt32Workload")
Finn Williamscbd2c232020-06-22 15:58:32 +0100240{
241 RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<int32_t>,
242 MultiplicationQueueDescriptor,
243 MultiplicationLayer,
244 armnn::DataType::Signed32>();
245}
246
Sadik Armagan1625efc2021-06-10 18:24:34 +0100247TEST_CASE("CreateDivisionFloat32Workload")
David Beckbc392452018-09-10 14:47:28 +0100248{
Finn Williamscbd2c232020-06-22 15:58:32 +0100249 RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000250 DivisionQueueDescriptor,
251 DivisionLayer,
252 armnn::DataType::Float32>();
David Beckbc392452018-09-10 14:47:28 +0100253}
254
Sadik Armagan1625efc2021-06-10 18:24:34 +0100255TEST_CASE("CreateDivisionFloat16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100256{
Finn Williamscbd2c232020-06-22 15:58:32 +0100257 RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100258 DivisionQueueDescriptor,
259 DivisionLayer,
260 armnn::DataType::Float16>();
261}
262
Sadik Armagan1625efc2021-06-10 18:24:34 +0100263TEST_CASE("CreateDivisionUint8Workload")
David Beckbc392452018-09-10 14:47:28 +0100264{
Finn Williamscbd2c232020-06-22 15:58:32 +0100265 RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000266 DivisionQueueDescriptor,
267 DivisionLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000268 armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000269}
270
Sadik Armagan1625efc2021-06-10 18:24:34 +0100271TEST_CASE("CreateDivisionInt16Workload")
Sadik Armagan2999a022019-04-09 14:20:12 +0100272{
Finn Williamscbd2c232020-06-22 15:58:32 +0100273 RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
Sadik Armagan2999a022019-04-09 14:20:12 +0100274 DivisionQueueDescriptor,
275 DivisionLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000276 armnn::DataType::QSymmS16>();
Sadik Armagan2999a022019-04-09 14:20:12 +0100277}
278
Sadik Armagan1625efc2021-06-10 18:24:34 +0100279TEST_CASE("CreateDivisionInt32Workload")
Finn Williamscbd2c232020-06-22 15:58:32 +0100280{
281 RefCreateElementwiseWorkloadTest<RefDivisionWorkload<int32_t>,
282 DivisionQueueDescriptor,
283 DivisionLayer,
284 armnn::DataType::Signed32>();
285}
286
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100287template <typename BatchNormalizationWorkloadType, armnn::DataType DataType>
288static void RefCreateBatchNormalizationWorkloadTest(DataLayout dataLayout)
telsoa014fcda012018-03-09 14:13:49 +0000289{
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100290 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100291 RefWorkloadFactory factory = GetFactory();
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100292 auto workload = CreateBatchNormalizationWorkloadTest<BatchNormalizationWorkloadType, DataType>(factory,
293 graph,
294 dataLayout);
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100295
296 TensorShape inputShape;
297 TensorShape outputShape;
298
299 switch (dataLayout)
300 {
301 case DataLayout::NHWC:
Nikhil Rajd1340932018-10-18 14:27:50 +0100302 inputShape = { 2, 4, 4, 3 };
303 outputShape = { 2, 4, 4, 3 };
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100304 break;
305 case DataLayout::NCHW:
306 default:
Nikhil Rajd1340932018-10-18 14:27:50 +0100307 inputShape = { 2, 3, 4, 4 };
308 outputShape = { 2, 3, 4, 4 };
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100309 break;
310 }
telsoa014fcda012018-03-09 14:13:49 +0000311
telsoa01c577f2c2018-08-31 09:22:23 +0100312 // Checks that outputs and inputs are as we expect them (see definition of CreateBatchNormalizationWorkloadTest).
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100313 CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
314}
315
Sadik Armagan1625efc2021-06-10 18:24:34 +0100316TEST_CASE("CreateBatchNormalizationWithBlobFloat32Workload")
Keith Davisdf04d232020-10-23 17:20:05 +0100317{
318 Graph graph;
319 RefWorkloadFactory factory = GetFactory();
320 auto dataType = armnn::DataType::Float32;
321 auto workload = CreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload,
322 armnn::DataType::Float32>(factory, graph, DataLayout::NHWC);
323
324 TensorShape inputShape;
325 TensorShape outputShape;
326
327 inputShape = { 2, 4, 4, 3 };
328 outputShape = { 2, 4, 4, 3 };
329
330 // Checks that outputs and inputs are as we expect them (see definition of CreateBatchNormalizationWorkloadTest).
331 CheckInputOutput(std::move(workload), TensorInfo(inputShape, dataType), TensorInfo(outputShape, dataType));
332}
333
Sadik Armagan1625efc2021-06-10 18:24:34 +0100334TEST_CASE("CreateBatchNormalizationFloat32Workload")
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100335{
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100336 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload,armnn::DataType::Float32>
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100337 (DataLayout::NCHW);
338}
339
Sadik Armagan1625efc2021-06-10 18:24:34 +0100340TEST_CASE("CreateBatchNormalizationFloat32WorkloadNhwc")
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100341{
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100342 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::Float32>
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100343 (DataLayout::NHWC);
344}
345
Sadik Armagan1625efc2021-06-10 18:24:34 +0100346TEST_CASE("CreateBatchNormalizationFloat16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100347{
348 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload,armnn::DataType::Float16>
349 (DataLayout::NCHW);
350}
351
Sadik Armagan1625efc2021-06-10 18:24:34 +0100352TEST_CASE("CreateBatchNormalizationFloat16WorkloadNhwc")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100353{
354 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::Float16>
355 (DataLayout::NHWC);
356}
357
Sadik Armagan1625efc2021-06-10 18:24:34 +0100358TEST_CASE("CreateBatchNormalizationUint8Workload")
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100359{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000360 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QAsymmU8>
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100361 (DataLayout::NCHW);
362}
363
Sadik Armagan1625efc2021-06-10 18:24:34 +0100364TEST_CASE("CreateBatchNormalizationUint8WorkloadNhwc")
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100365{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000366 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QAsymmU8>
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100367 (DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000368}
369
Sadik Armagan1625efc2021-06-10 18:24:34 +0100370TEST_CASE("CreateBatchNormalizationInt16Workload")
Matteo Martincighf5507132019-06-04 10:59:47 +0100371{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000372 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QSymmS16>
Matteo Martincighf5507132019-06-04 10:59:47 +0100373 (DataLayout::NCHW);
374}
375
Sadik Armagan1625efc2021-06-10 18:24:34 +0100376TEST_CASE("CreateBatchNormalizationInt16WorkloadNhwc")
Matteo Martincighf5507132019-06-04 10:59:47 +0100377{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000378 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QSymmS16>
Matteo Martincighf5507132019-06-04 10:59:47 +0100379 (DataLayout::NHWC);
380}
381
Sadik Armagan1625efc2021-06-10 18:24:34 +0100382TEST_CASE("CreateConvertFp16ToFp32Float32Workload")
telsoa01c577f2c2018-08-31 09:22:23 +0100383{
384 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100385 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +0100386 auto workload = CreateConvertFp16ToFp32WorkloadTest<RefConvertFp16ToFp32Workload>(factory, graph);
387
388 // Checks that outputs and inputs are as we expect them
389 CheckInputOutput(
390 std::move(workload), TensorInfo({1, 3, 2, 3}, DataType::Float16), TensorInfo({1, 3, 2, 3}, DataType::Float32));
391}
392
Sadik Armagan1625efc2021-06-10 18:24:34 +0100393TEST_CASE("CreateConvertFp32ToFp16Float16Workload")
telsoa01c577f2c2018-08-31 09:22:23 +0100394{
395 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100396 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +0100397 auto workload = CreateConvertFp32ToFp16WorkloadTest<RefConvertFp32ToFp16Workload>(factory, graph);
398
399 // Checks that outputs and inputs are as we expect them
400 CheckInputOutput(
401 std::move(workload), TensorInfo({1, 3, 2, 3}, DataType::Float32), TensorInfo({1, 3, 2, 3}, DataType::Float16));
402}
403
Nikhil Raje4dfd6e2018-10-18 10:11:04 +0100404static void RefCreateConvolution2dWorkloadTest(DataLayout dataLayout = DataLayout::NCHW)
telsoa014fcda012018-03-09 14:13:49 +0000405{
Nikhil Raje4dfd6e2018-10-18 10:11:04 +0100406 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100407 RefWorkloadFactory factory = GetFactory();
Mike Kelly9b398322019-05-22 17:21:49 +0100408 auto workload = CreateConvolution2dWorkloadTest<RefConvolution2dWorkload, DataType::Float32>
Nikhil Raje4dfd6e2018-10-18 10:11:04 +0100409 (factory, graph, dataLayout);
410
Mike Kellydb482882019-06-14 12:35:24 +0100411 TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({2, 3, 8, 16})
412 : std::initializer_list<unsigned int>({2, 8, 16, 3});
413 TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({2, 2, 2, 10})
414 : std::initializer_list<unsigned int>({2, 2, 10, 2});
telsoa014fcda012018-03-09 14:13:49 +0000415
telsoa01c577f2c2018-08-31 09:22:23 +0100416 // Checks that outputs and inputs are as we expect them (see definition of CreateConvolution2dWorkloadTest).
telsoa014fcda012018-03-09 14:13:49 +0000417 CheckInputOutput(std::move(workload),
Nikhil Raje4dfd6e2018-10-18 10:11:04 +0100418 TensorInfo(inputShape, DataType::Float32),
419 TensorInfo(outputShape, DataType::Float32));
420}
421
Sadik Armagan1625efc2021-06-10 18:24:34 +0100422TEST_CASE("CreateConvolution2dFloatNchwWorkload")
Nikhil Raje4dfd6e2018-10-18 10:11:04 +0100423{
424 RefCreateConvolution2dWorkloadTest(DataLayout::NCHW);
425}
426
Sadik Armagan1625efc2021-06-10 18:24:34 +0100427TEST_CASE("CreateConvolution2dFloatNhwcWorkload")
Nikhil Raje4dfd6e2018-10-18 10:11:04 +0100428{
429 RefCreateConvolution2dWorkloadTest(DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000430}
431
Sadik Armagan1625efc2021-06-10 18:24:34 +0100432TEST_CASE("CreateConvolution2dWithBlobWorkload")
Keith Davisdf04d232020-10-23 17:20:05 +0100433{
434 DataLayout dataLayout = DataLayout::NHWC;
435 Graph graph;
436 RefWorkloadFactory factory = GetFactory();
437 auto workload = CreateConvolution2dFusedActivationWithBlobWorkloadTest<RefConvolution2dWorkload, DataType::Float32>
438 (factory, graph, dataLayout);
439
440 TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({2, 3, 8, 16})
441 : std::initializer_list<unsigned int>({2, 8, 16, 3});
442 TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({2, 2, 2, 10})
443 : std::initializer_list<unsigned int>({2, 2, 10, 2});
444
445 // Checks that outputs and inputs are as we expect them (see definition of CreateConvolution2dWorkloadTest).
446 CheckInputOutput(std::move(workload),
447 TensorInfo(inputShape, DataType::Float32),
448 TensorInfo(outputShape, DataType::Float32));
449}
450
Ruomei Yan495852f2019-05-23 11:37:33 +0100451static void RefCreateDepthwiseConvolutionWorkloadTest(DataLayout dataLayout)
452{
453 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100454 RefWorkloadFactory factory = GetFactory();
Ruomei Yan495852f2019-05-23 11:37:33 +0100455 auto workload = CreateDepthwiseConvolution2dWorkloadTest<RefDepthwiseConvolution2dWorkload, DataType::Float32>
456 (factory, graph, dataLayout);
457
Mike Kellydb482882019-06-14 12:35:24 +0100458 TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({ 2, 2, 5, 5 })
459 : std::initializer_list<unsigned int>({ 2, 5, 5, 2 });
460 TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({ 2, 2, 5, 5 })
461 : std::initializer_list<unsigned int>({ 2, 5, 5, 2 });
462
Ruomei Yan495852f2019-05-23 11:37:33 +0100463 // Checks that inputs/outputs are as we expect them (see definition of CreateDepthwiseConvolution2dWorkloadTest).
464 CheckInputOutput(std::move(workload),
465 TensorInfo(inputShape, DataType::Float32),
466 TensorInfo(outputShape, DataType::Float32));
467}
468
Sadik Armagan1625efc2021-06-10 18:24:34 +0100469TEST_CASE("CreateDepthwiseConvolutionFloat32NhwcWorkload")
Ruomei Yan495852f2019-05-23 11:37:33 +0100470{
471 RefCreateDepthwiseConvolutionWorkloadTest(DataLayout::NHWC);
472}
473
Sadik Armagan1625efc2021-06-10 18:24:34 +0100474TEST_CASE("RefCreateFullyConnectedWithBlobWorkloadTest")
Keith Davisdf04d232020-10-23 17:20:05 +0100475{
476 Graph graph;
477 RefWorkloadFactory factory = GetFactory();
478 auto workload = CreateFullyConnectedWithBlobWorkloadTest<RefFullyConnectedWorkload,
479 armnn::DataType::Float32>(factory, graph);
480
481 // Checks that outputs and inputs are as we expect them (see definition of CreateFullyConnectedWorkloadTest).
Teresa Charlinacb3ec52023-04-03 19:57:00 +0100482 float inputsQScale = 1.0f;
483 float outputQScale = 1.0f;
Keith Davisdf04d232020-10-23 17:20:05 +0100484 CheckInputOutput(std::move(workload),
485 TensorInfo({ 3, 1, 4, 5 }, armnn::DataType::Float32, inputsQScale),
486 TensorInfo({ 3, 7 }, armnn::DataType::Float32, outputQScale));
487}
488
Matthew Sloyan81beae32021-07-13 19:46:11 +0100489TEST_CASE("CreateFullyConnectedWorkloadWeightsBiasesAsInputsFloat32")
490{
491 Graph graph;
492 RefWorkloadFactory factory = GetFactory();
493
494 auto workload =
495 CreateFullyConnectedWorkloadWeightsBiasesAsInputsTest<RefFullyConnectedWorkload,
496 armnn::DataType::Float32>(factory, graph);
497
498 // Checks that outputs and inputs are as we expect them (see definition of CreateFullyConnectedWorkloadTest).
Teresa Charlinacb3ec52023-04-03 19:57:00 +0100499 float inputsQScale = 1.0f;
500 float outputQScale = 1.0f;
Matthew Sloyan81beae32021-07-13 19:46:11 +0100501 CheckInputsOutput(std::move(workload),
502 TensorInfo({ 3, 1, 4, 5 }, armnn::DataType::Float32, inputsQScale),
503 TensorInfo({ 7, 20 }, armnn::DataType::Float32, inputsQScale),
504 TensorInfo({ 3, 7 }, armnn::DataType::Float32, outputQScale));
505}
506
telsoa01c577f2c2018-08-31 09:22:23 +0100507template <typename FullyConnectedWorkloadType, armnn::DataType DataType>
telsoa014fcda012018-03-09 14:13:49 +0000508static void RefCreateFullyConnectedWorkloadTest()
509{
510 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100511 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +0100512 auto workload = CreateFullyConnectedWorkloadTest<FullyConnectedWorkloadType, DataType>(factory, graph);
telsoa014fcda012018-03-09 14:13:49 +0000513
telsoa01c577f2c2018-08-31 09:22:23 +0100514 // Checks that outputs and inputs are as we expect them (see definition of CreateFullyConnectedWorkloadTest).
Teresa Charlinacb3ec52023-04-03 19:57:00 +0100515 float inputsQScale = DataType == armnn::DataType::QAsymmU8 ? 1.0f : 1.0f;
516 float outputQScale = DataType == armnn::DataType::QAsymmU8 ? 2.0f : 1.0f;
telsoa014fcda012018-03-09 14:13:49 +0000517 CheckInputOutput(std::move(workload),
telsoa01c577f2c2018-08-31 09:22:23 +0100518 TensorInfo({ 3, 1, 4, 5 }, DataType, inputsQScale),
519 TensorInfo({ 3, 7 }, DataType, outputQScale));
telsoa014fcda012018-03-09 14:13:49 +0000520}
521
Sadik Armagan1625efc2021-06-10 18:24:34 +0100522TEST_CASE("CreateFullyConnectedWorkloadFloat32")
telsoa014fcda012018-03-09 14:13:49 +0000523{
Francis Murtagh43aec582019-05-27 12:14:10 +0100524 RefCreateFullyConnectedWorkloadTest<RefFullyConnectedWorkload, armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000525}
526
Sadik Armagan1625efc2021-06-10 18:24:34 +0100527TEST_CASE("CreateFullyConnectedWorkloadQuantisedAsymm8")
telsoa014fcda012018-03-09 14:13:49 +0000528{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000529 RefCreateFullyConnectedWorkloadTest<RefFullyConnectedWorkload, armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000530}
531
Sadik Armagan1625efc2021-06-10 18:24:34 +0100532TEST_CASE("CreateFullyConnectedWorkloadQuantisedSymm16")
Francis Murtagh46c09d02019-05-28 08:15:28 +0100533{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000534 RefCreateFullyConnectedWorkloadTest<RefFullyConnectedWorkload, armnn::DataType::QSymmS16>();
Francis Murtagh46c09d02019-05-28 08:15:28 +0100535}
536
narpra0155a97bc2018-10-02 14:35:53 +0100537template <typename NormalizationWorkloadType, armnn::DataType DataType>
Matteo Martincigha160b242018-10-18 10:33:23 +0100538static void RefCreateNormalizationWorkloadTest(DataLayout dataLayout)
telsoa014fcda012018-03-09 14:13:49 +0000539{
narpra0155a97bc2018-10-02 14:35:53 +0100540 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100541 RefWorkloadFactory factory = GetFactory();
Matteo Martincigha160b242018-10-18 10:33:23 +0100542 auto workload = CreateNormalizationWorkloadTest<NormalizationWorkloadType, DataType>(factory, graph, dataLayout);
543
544 TensorShape inputShape;
545 TensorShape outputShape;
546
547 switch (dataLayout)
548 {
549 case DataLayout::NHWC:
550 inputShape = { 3, 1, 5, 5 };
551 outputShape = { 3, 1, 5, 5 };
552 break;
553 case DataLayout::NCHW:
554 default:
555 inputShape = { 3, 5, 5, 1 };
556 outputShape = { 3, 5, 5, 1 };
557 break;
558 }
telsoa014fcda012018-03-09 14:13:49 +0000559
telsoa01c577f2c2018-08-31 09:22:23 +0100560 // Checks that outputs and inputs are as we expect them (see definition of CreateNormalizationWorkloadTest).
Matteo Martincigha160b242018-10-18 10:33:23 +0100561 CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
narpra0155a97bc2018-10-02 14:35:53 +0100562}
563
Sadik Armagan1625efc2021-06-10 18:24:34 +0100564TEST_CASE("CreateRefNormalizationFloat32NchwWorkload")
narpra0155a97bc2018-10-02 14:35:53 +0100565{
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100566 RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
Matteo Martincigha160b242018-10-18 10:33:23 +0100567}
568
Sadik Armagan1625efc2021-06-10 18:24:34 +0100569TEST_CASE("CreateRefNormalizationFloat32NhwcWorkload")
Matteo Martincigha160b242018-10-18 10:33:23 +0100570{
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100571 RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
572}
573
Sadik Armagan1625efc2021-06-10 18:24:34 +0100574TEST_CASE("CreateRefNormalizationUint8NchwWorkload")
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100575{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000576 RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QAsymmU8>(DataLayout::NCHW);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100577}
578
Sadik Armagan1625efc2021-06-10 18:24:34 +0100579TEST_CASE("CreateRefNormalizationUint8NhwcWorkload")
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100580{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000581 RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QAsymmU8>(DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000582}
583
Sadik Armagan1625efc2021-06-10 18:24:34 +0100584TEST_CASE("CreateRefNormalizationInt16NchwWorkload")
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100585{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000586 RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QSymmS16>(DataLayout::NCHW);
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100587}
588
Sadik Armagan1625efc2021-06-10 18:24:34 +0100589TEST_CASE("CreateRefNormalizationInt16NhwcWorkload")
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100590{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000591 RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QSymmS16>(DataLayout::NHWC);
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100592}
593
telsoa01c577f2c2018-08-31 09:22:23 +0100594template <typename Pooling2dWorkloadType, armnn::DataType DataType>
James Conroy69482272018-10-19 10:41:35 +0100595static void RefCreatePooling2dWorkloadTest(DataLayout dataLayout)
telsoa014fcda012018-03-09 14:13:49 +0000596{
597 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100598 RefWorkloadFactory factory = GetFactory();
James Conroy69482272018-10-19 10:41:35 +0100599 auto workload = CreatePooling2dWorkloadTest<Pooling2dWorkloadType, DataType>(factory, graph, dataLayout);
600
601 TensorShape inputShape;
602 TensorShape outputShape;
603
604 switch (dataLayout)
605 {
606 case DataLayout::NHWC:
607 inputShape = { 3, 5, 5, 2 };
608 outputShape = { 3, 2, 4, 2 };
609 break;
610 case DataLayout::NCHW:
611 default:
612 inputShape = { 3, 2, 5, 5 };
613 outputShape = { 3, 2, 2, 4 };
614 }
telsoa014fcda012018-03-09 14:13:49 +0000615
telsoa01c577f2c2018-08-31 09:22:23 +0100616 // Checks that outputs and inputs are as we expect them (see definition of CreatePooling2dWorkloadTest).
James Conroy69482272018-10-19 10:41:35 +0100617 CheckInputOutput(std::move(workload),
618 TensorInfo(inputShape, DataType),
619 TensorInfo(outputShape, DataType));
telsoa014fcda012018-03-09 14:13:49 +0000620}
621
Sadik Armagan1625efc2021-06-10 18:24:34 +0100622TEST_CASE("CreatePooling2dFloat32Workload")
telsoa014fcda012018-03-09 14:13:49 +0000623{
Teresa Charlina3b20472019-06-06 11:12:32 +0100624 RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
James Conroy69482272018-10-19 10:41:35 +0100625}
626
Sadik Armagan1625efc2021-06-10 18:24:34 +0100627TEST_CASE("CreatePooling2dFloat32NhwcWorkload")
James Conroy69482272018-10-19 10:41:35 +0100628{
Teresa Charlina3b20472019-06-06 11:12:32 +0100629 RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000630}
631
Sadik Armagan1625efc2021-06-10 18:24:34 +0100632TEST_CASE("CreatePooling2dUint8Workload")
telsoa014fcda012018-03-09 14:13:49 +0000633{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000634 RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::QAsymmU8>(DataLayout::NCHW);
James Conroy69482272018-10-19 10:41:35 +0100635}
636
Sadik Armagan1625efc2021-06-10 18:24:34 +0100637TEST_CASE("CreatePooling2dUint8NhwcWorkload")
James Conroy69482272018-10-19 10:41:35 +0100638{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000639 RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::QAsymmU8>(DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000640}
641
Sadik Armagan1625efc2021-06-10 18:24:34 +0100642TEST_CASE("CreatePooling2dInt16Workload")
Teresa Charlin0434df62019-06-06 13:40:35 +0100643{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000644 RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::QSymmS16>(DataLayout::NCHW);
Teresa Charlin0434df62019-06-06 13:40:35 +0100645}
646
Sadik Armagan1625efc2021-06-10 18:24:34 +0100647TEST_CASE("CreatePooling2dInt16NhwcWorkload")
Teresa Charlin0434df62019-06-06 13:40:35 +0100648{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000649 RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::QSymmS16>(DataLayout::NHWC);
Teresa Charlin0434df62019-06-06 13:40:35 +0100650}
651
telsoa01c577f2c2018-08-31 09:22:23 +0100652template <typename SoftmaxWorkloadType, armnn::DataType DataType>
telsoa014fcda012018-03-09 14:13:49 +0000653static void RefCreateSoftmaxWorkloadTest()
654{
655 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100656 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +0100657 auto workload = CreateSoftmaxWorkloadTest<SoftmaxWorkloadType, DataType>(factory, graph);
telsoa014fcda012018-03-09 14:13:49 +0000658
telsoa01c577f2c2018-08-31 09:22:23 +0100659 // Checks that outputs and inputs are as we expect them (see definition of CreateSoftmaxWorkloadTest).
Sadik Armaganbe88a572020-04-30 11:39:37 +0100660
661 armnn::TensorInfo tensorInfo({4, 1}, DataType);
662 if (DataType == armnn::DataType::QAsymmU8)
663 {
664 tensorInfo.SetQuantizationOffset(0);
665 tensorInfo.SetQuantizationScale(1.f / 256);
666 }
667 else if (DataType == armnn::DataType::QAsymmS8)
668 {
669 tensorInfo.SetQuantizationOffset(-128);
670 tensorInfo.SetQuantizationScale(1.f / 256);
671 }
telsoa014fcda012018-03-09 14:13:49 +0000672 CheckInputOutput(
673 std::move(workload),
Sadik Armaganbe88a572020-04-30 11:39:37 +0100674 tensorInfo,
675 tensorInfo);
telsoa014fcda012018-03-09 14:13:49 +0000676}
677
Sadik Armagan1625efc2021-06-10 18:24:34 +0100678TEST_CASE("CreateSoftmaxFloat32Workload")
telsoa014fcda012018-03-09 14:13:49 +0000679{
nikraj01a121de32019-05-29 10:51:05 +0100680 RefCreateSoftmaxWorkloadTest<RefSoftmaxWorkload, armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000681}
682
Sadik Armagan1625efc2021-06-10 18:24:34 +0100683TEST_CASE("CreateSoftmaxFloat16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100684{
685 RefCreateSoftmaxWorkloadTest<RefSoftmaxWorkload, armnn::DataType::Float16>();
686}
687
Sadik Armagan1625efc2021-06-10 18:24:34 +0100688TEST_CASE("CreateSoftmaxQuantisedAsymm8Workload")
telsoa014fcda012018-03-09 14:13:49 +0000689{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000690 RefCreateSoftmaxWorkloadTest<RefSoftmaxWorkload, armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000691}
692
Sadik Armagan1625efc2021-06-10 18:24:34 +0100693TEST_CASE("CreateSoftmaxQuantisedSymm16Workload")
nikraj01248683f2019-05-29 16:46:50 +0100694{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000695 RefCreateSoftmaxWorkloadTest<RefSoftmaxWorkload, armnn::DataType::QSymmS16>();
nikraj01248683f2019-05-29 16:46:50 +0100696}
697
telsoa01c577f2c2018-08-31 09:22:23 +0100698template <typename SplitterWorkloadType, armnn::DataType DataType>
telsoa014fcda012018-03-09 14:13:49 +0000699static void RefCreateSplitterWorkloadTest()
700{
701 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100702 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +0100703 auto workload = CreateSplitterWorkloadTest<SplitterWorkloadType, DataType>(factory, graph);
telsoa014fcda012018-03-09 14:13:49 +0000704
telsoa01c577f2c2018-08-31 09:22:23 +0100705 // Checks that outputs are as we expect them (see definition of CreateSplitterWorkloadTest).
telsoa014fcda012018-03-09 14:13:49 +0000706 SplitterQueueDescriptor queueDescriptor = workload->GetData();
Jan Eilersbb446e52020-04-02 13:56:54 +0100707 auto inputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100708 CHECK((inputHandle->GetTensorInfo() == TensorInfo({ 5, 7, 7 }, DataType)));
surmeh013537c2c2018-05-18 16:31:43 +0100709
Jan Eilersbb446e52020-04-02 13:56:54 +0100710 auto outputHandle0 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100711 CHECK((outputHandle0->GetTensorInfo() == TensorInfo({ 1, 7, 7 }, DataType)));
surmeh013537c2c2018-05-18 16:31:43 +0100712
Jan Eilersbb446e52020-04-02 13:56:54 +0100713 auto outputHandle1 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[1]);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100714 CHECK((outputHandle1->GetTensorInfo() == TensorInfo({ 2, 7, 7 }, DataType)));
surmeh013537c2c2018-05-18 16:31:43 +0100715
Jan Eilersbb446e52020-04-02 13:56:54 +0100716 auto outputHandle2 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[2]);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100717 CHECK((outputHandle2->GetTensorInfo() == TensorInfo({ 2, 7, 7 }, DataType)));
telsoa014fcda012018-03-09 14:13:49 +0000718}
719
Sadik Armagan1625efc2021-06-10 18:24:34 +0100720TEST_CASE("CreateSplitterFloat32Workload")
telsoa014fcda012018-03-09 14:13:49 +0000721{
Ruomei Yan25339c32019-05-28 16:48:20 +0100722 RefCreateSplitterWorkloadTest<RefSplitterWorkload, armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000723}
724
Sadik Armagan1625efc2021-06-10 18:24:34 +0100725TEST_CASE("CreateSplitterFloat16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100726{
727 RefCreateSplitterWorkloadTest<RefSplitterWorkload, armnn::DataType::Float16>();
728}
729
Sadik Armagan1625efc2021-06-10 18:24:34 +0100730TEST_CASE("CreateSplitterUint8Workload")
telsoa014fcda012018-03-09 14:13:49 +0000731{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000732 RefCreateSplitterWorkloadTest<RefSplitterWorkload, armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000733}
734
Jim Flynne242f2d2019-05-22 14:24:13 +0100735template <typename SplitterWorkloadType, typename ConcatWorkloadType, armnn::DataType DataType>
736static void RefCreateSplitterConcatWorkloadTest()
telsoa014fcda012018-03-09 14:13:49 +0000737{
telsoa01c577f2c2018-08-31 09:22:23 +0100738 // Tests that it is possible to decide which output of the splitter layer
Jim Flynne242f2d2019-05-22 14:24:13 +0100739 // should be lined to which input of the concat layer.
telsoa01c577f2c2018-08-31 09:22:23 +0100740 // We tested that is is possible to specify 0th output
Jim Flynne242f2d2019-05-22 14:24:13 +0100741 // of the splitter to be the 1st input to the concat and the 1st output of the splitter to be 0th input
742 // of the concat.
telsoa014fcda012018-03-09 14:13:49 +0000743
744 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100745 RefWorkloadFactory factory = GetFactory();
Jim Flynne242f2d2019-05-22 14:24:13 +0100746 auto workloads = CreateSplitterConcatWorkloadTest<SplitterWorkloadType, ConcatWorkloadType, DataType>
747 (factory, graph);
telsoa014fcda012018-03-09 14:13:49 +0000748
749 auto wlSplitter = std::move(workloads.first);
Jim Flynne242f2d2019-05-22 14:24:13 +0100750 auto wlConcat = std::move(workloads.second);
telsoa014fcda012018-03-09 14:13:49 +0000751
telsoa01c577f2c2018-08-31 09:22:23 +0100752 //Checks that the index of inputs/outputs matches what we declared on InputDescriptor construction.
Matthew Bentham4cefc412019-06-18 16:14:34 +0100753 armnn::RefTensorHandle* sOut0 = dynamic_cast<armnn::RefTensorHandle*>(wlSplitter->GetData().m_Outputs[0]);
754 armnn::RefTensorHandle* sOut1 = dynamic_cast<armnn::RefTensorHandle*>(wlSplitter->GetData().m_Outputs[1]);
755 armnn::RefTensorHandle* mIn0 = dynamic_cast<armnn::RefTensorHandle*>(wlConcat->GetData().m_Inputs[0]);
756 armnn::RefTensorHandle* mIn1 = dynamic_cast<armnn::RefTensorHandle*>(wlConcat->GetData().m_Inputs[1]);
telsoa014fcda012018-03-09 14:13:49 +0000757
Sadik Armagan1625efc2021-06-10 18:24:34 +0100758 CHECK(sOut0);
759 CHECK(sOut1);
760 CHECK(mIn0);
761 CHECK(mIn1);
telsoa014fcda012018-03-09 14:13:49 +0000762
763 bool validDataPointers = (sOut0 == mIn1) && (sOut1 == mIn0);
764
Sadik Armagan1625efc2021-06-10 18:24:34 +0100765 CHECK(validDataPointers);
telsoa014fcda012018-03-09 14:13:49 +0000766}
767
Sadik Armagan1625efc2021-06-10 18:24:34 +0100768TEST_CASE("CreateSplitterConcatFloat32")
telsoa014fcda012018-03-09 14:13:49 +0000769{
Ruomei Yan25339c32019-05-28 16:48:20 +0100770 RefCreateSplitterConcatWorkloadTest<RefSplitterWorkload, RefConcatWorkload, DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000771}
772
Sadik Armagan1625efc2021-06-10 18:24:34 +0100773TEST_CASE("CreateSplitterConcatFloat16")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100774{
775 RefCreateSplitterConcatWorkloadTest<RefSplitterWorkload, RefConcatWorkload, DataType::Float16>();
776}
777
Sadik Armagan1625efc2021-06-10 18:24:34 +0100778TEST_CASE("CreateSplitterConcatUint8")
telsoa014fcda012018-03-09 14:13:49 +0000779{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000780 RefCreateSplitterConcatWorkloadTest<RefSplitterWorkload, RefConcatWorkload, DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000781}
782
telsoa01c577f2c2018-08-31 09:22:23 +0100783template <typename SplitterWorkloadType, typename ActivationWorkloadType, armnn::DataType DataType>
telsoa014fcda012018-03-09 14:13:49 +0000784static void RefCreateSingleOutputMultipleInputsTest()
785{
telsoa01c577f2c2018-08-31 09:22:23 +0100786 // Tests that it is possible to assign multiple (two) different layers to each of the outputs of a splitter layer.
787 // We created a splitter with two outputs. That each of those outputs is used by two different activation layers.
telsoa014fcda012018-03-09 14:13:49 +0000788
789 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100790 RefWorkloadFactory factory = GetFactory();
telsoa014fcda012018-03-09 14:13:49 +0000791 std::unique_ptr<SplitterWorkloadType> wlSplitter;
792 std::unique_ptr<ActivationWorkloadType> wlActiv0_0;
793 std::unique_ptr<ActivationWorkloadType> wlActiv0_1;
794 std::unique_ptr<ActivationWorkloadType> wlActiv1_0;
795 std::unique_ptr<ActivationWorkloadType> wlActiv1_1;
796
797 CreateSplitterMultipleInputsOneOutputWorkloadTest<SplitterWorkloadType,
telsoa01c577f2c2018-08-31 09:22:23 +0100798 ActivationWorkloadType, DataType>(factory, graph, wlSplitter, wlActiv0_0, wlActiv0_1, wlActiv1_0, wlActiv1_1);
telsoa014fcda012018-03-09 14:13:49 +0000799
Matthew Bentham4cefc412019-06-18 16:14:34 +0100800 armnn::RefTensorHandle* sOut0 = dynamic_cast<armnn::RefTensorHandle*>(wlSplitter->GetData().m_Outputs[0]);
801 armnn::RefTensorHandle* sOut1 = dynamic_cast<armnn::RefTensorHandle*>(wlSplitter->GetData().m_Outputs[1]);
802 armnn::RefTensorHandle* activ0_0Im = dynamic_cast<armnn::RefTensorHandle*>(wlActiv0_0->GetData().m_Inputs[0]);
803 armnn::RefTensorHandle* activ0_1Im = dynamic_cast<armnn::RefTensorHandle*>(wlActiv0_1->GetData().m_Inputs[0]);
804 armnn::RefTensorHandle* activ1_0Im = dynamic_cast<armnn::RefTensorHandle*>(wlActiv1_0->GetData().m_Inputs[0]);
805 armnn::RefTensorHandle* activ1_1Im = dynamic_cast<armnn::RefTensorHandle*>(wlActiv1_1->GetData().m_Inputs[0]);
telsoa014fcda012018-03-09 14:13:49 +0000806
807
Sadik Armagan1625efc2021-06-10 18:24:34 +0100808 CHECK(sOut0);
809 CHECK(sOut1);
810 CHECK(activ0_0Im);
811 CHECK(activ0_1Im);
812 CHECK(activ1_0Im);
813 CHECK(activ1_1Im);
telsoa014fcda012018-03-09 14:13:49 +0000814
815 bool validDataPointers = (sOut0 == activ0_0Im) && (sOut0 == activ0_1Im) &&
816 (sOut1 == activ1_0Im) && (sOut1 == activ1_1Im);
817
Sadik Armagan1625efc2021-06-10 18:24:34 +0100818 CHECK(validDataPointers);
telsoa014fcda012018-03-09 14:13:49 +0000819}
820
Sadik Armagan1625efc2021-06-10 18:24:34 +0100821TEST_CASE("CreateSingleOutputMultipleInputsFloat32")
telsoa014fcda012018-03-09 14:13:49 +0000822{
Ruomei Yan25339c32019-05-28 16:48:20 +0100823 RefCreateSingleOutputMultipleInputsTest<RefSplitterWorkload, RefActivationWorkload,
telsoa01c577f2c2018-08-31 09:22:23 +0100824 armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000825}
826
Sadik Armagan1625efc2021-06-10 18:24:34 +0100827TEST_CASE("CreateSingleOutputMultipleInputsUint8")
telsoa014fcda012018-03-09 14:13:49 +0000828{
Ruomei Yan25339c32019-05-28 16:48:20 +0100829 RefCreateSingleOutputMultipleInputsTest<RefSplitterWorkload, RefActivationWorkload,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000830 armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000831}
832
telsoa01c577f2c2018-08-31 09:22:23 +0100833template <typename ResizeBilinearWorkloadType, armnn::DataType DataType>
James Conroy59540822018-10-11 12:39:05 +0100834static void RefCreateResizeBilinearTest(DataLayout dataLayout)
telsoa014fcda012018-03-09 14:13:49 +0000835{
836 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100837 RefWorkloadFactory factory = GetFactory();
James Conroy59540822018-10-11 12:39:05 +0100838 auto workload = CreateResizeBilinearWorkloadTest<ResizeBilinearWorkloadType, DataType>(factory, graph, dataLayout);
839
840 TensorShape inputShape;
841 TensorShape outputShape;
842
843 switch (dataLayout)
844 {
845 case DataLayout::NHWC:
846 inputShape = { 2, 4, 4, 3 };
847 outputShape = { 2, 2, 2, 3 };
848 break;
James Conroy69482272018-10-19 10:41:35 +0100849 case DataLayout::NCHW:
850 default:
James Conroy59540822018-10-11 12:39:05 +0100851 inputShape = { 2, 3, 4, 4 };
852 outputShape = { 2, 3, 2, 2 };
853 }
telsoa014fcda012018-03-09 14:13:49 +0000854
telsoa01c577f2c2018-08-31 09:22:23 +0100855 // Checks that outputs and inputs are as we expect them (see definition of CreateResizeBilinearWorkloadTest).
James Conroy69482272018-10-19 10:41:35 +0100856 CheckInputOutput(std::move(workload),
857 TensorInfo(inputShape, DataType),
858 TensorInfo(outputShape, DataType));
telsoa014fcda012018-03-09 14:13:49 +0000859}
860
Sadik Armagan1625efc2021-06-10 18:24:34 +0100861TEST_CASE("CreateResizeBilinearFloat32")
telsoa014fcda012018-03-09 14:13:49 +0000862{
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100863 RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
telsoa014fcda012018-03-09 14:13:49 +0000864}
865
Sadik Armagan1625efc2021-06-10 18:24:34 +0100866TEST_CASE("CreateResizeBilinearFloat16")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100867{
868 RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::Float16>(DataLayout::NCHW);
869}
870
Sadik Armagan1625efc2021-06-10 18:24:34 +0100871TEST_CASE("CreateResizeBilinearUint8")
telsoa014fcda012018-03-09 14:13:49 +0000872{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000873 RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::QAsymmU8>(DataLayout::NCHW);
James Conroy59540822018-10-11 12:39:05 +0100874}
875
Sadik Armagan1625efc2021-06-10 18:24:34 +0100876TEST_CASE("CreateResizeBilinearQuantisedAsymm16")
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +0100877{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000878 RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::QSymmS16>(DataLayout::NCHW);
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +0100879}
880
Sadik Armagan1625efc2021-06-10 18:24:34 +0100881TEST_CASE("CreateResizeBilinearFloat32Nhwc")
James Conroy59540822018-10-11 12:39:05 +0100882{
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100883 RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000884}
885
Francis Murtagh57f13d52019-06-24 14:24:36 +0100886template <typename BatchToSpaceNdWorkloadType, armnn::DataType DataType>
887static void RefCreateBatchToSpaceNdTest()
888{
889 Graph graph;
890 RefWorkloadFactory factory;
891
892 auto workload = CreateBatchToSpaceNdWorkloadTest<BatchToSpaceNdWorkloadType, DataType>(factory, graph);
893
894 CheckInputOutput(std::move(workload),
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100895 TensorInfo({ 1, 1, 1, 1 }, DataType),
896 TensorInfo({ 1, 1, 1, 1 }, DataType));
Francis Murtagh57f13d52019-06-24 14:24:36 +0100897}
898
Sadik Armagan1625efc2021-06-10 18:24:34 +0100899TEST_CASE("CreateBatchToSpaceNdFloat32")
Francis Murtagh57f13d52019-06-24 14:24:36 +0100900{
901 RefCreateBatchToSpaceNdTest<RefBatchToSpaceNdWorkload, armnn::DataType::Float32>();
902}
903
Sadik Armagan1625efc2021-06-10 18:24:34 +0100904TEST_CASE("CreateBatchToSpaceNdFloat16")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100905{
906 RefCreateBatchToSpaceNdTest<RefBatchToSpaceNdWorkload, armnn::DataType::Float16>();
907}
908
Sadik Armagan1625efc2021-06-10 18:24:34 +0100909TEST_CASE("CreateBatchToSpaceNdUint8")
Francis Murtagh57f13d52019-06-24 14:24:36 +0100910{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000911 RefCreateBatchToSpaceNdTest<RefBatchToSpaceNdWorkload, armnn::DataType::QAsymmU8>();
Francis Murtagh57f13d52019-06-24 14:24:36 +0100912}
913
Sadik Armagan1625efc2021-06-10 18:24:34 +0100914TEST_CASE("CreateBatchToSpaceNdQSymm16")
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100915{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000916 RefCreateBatchToSpaceNdTest<RefBatchToSpaceNdWorkload, armnn::DataType::QSymmS16>();
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100917}
918
Matteo Martincighb63973e2018-10-16 16:23:33 +0100919template <typename L2NormalizationWorkloadType, armnn::DataType DataType>
920static void RefCreateL2NormalizationTest(DataLayout dataLayout)
telsoa014fcda012018-03-09 14:13:49 +0000921{
922 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100923 RefWorkloadFactory factory = GetFactory();
Matteo Martincighb63973e2018-10-16 16:23:33 +0100924 auto workload =
925 CreateL2NormalizationWorkloadTest<L2NormalizationWorkloadType, DataType>(factory, graph, dataLayout);
926
927 TensorShape inputShape;
928 TensorShape outputShape;
929
930 switch (dataLayout)
931 {
932 case DataLayout::NHWC:
933 inputShape = { 5, 50, 67, 20 };
934 outputShape = { 5, 50, 67, 20 };
935 break;
936 case DataLayout::NCHW:
937 default:
938 inputShape = { 5, 20, 50, 67 };
939 outputShape = { 5, 20, 50, 67 };
940 break;
941 }
telsoa014fcda012018-03-09 14:13:49 +0000942
telsoa01c577f2c2018-08-31 09:22:23 +0100943 // Checks that outputs and inputs are as we expect them (see definition of CreateL2NormalizationWorkloadTest).
Matteo Martincighb63973e2018-10-16 16:23:33 +0100944 CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
945}
946
Sadik Armagan1625efc2021-06-10 18:24:34 +0100947TEST_CASE("CreateL2NormalizationFloat32")
Matteo Martincighb63973e2018-10-16 16:23:33 +0100948{
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100949 RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
Matteo Martincighb63973e2018-10-16 16:23:33 +0100950}
951
Sadik Armagan1625efc2021-06-10 18:24:34 +0100952TEST_CASE("CreateL2NormalizationFloat32Nhwc")
Matteo Martincighb63973e2018-10-16 16:23:33 +0100953{
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100954 RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
955}
956
Sadik Armagan1625efc2021-06-10 18:24:34 +0100957TEST_CASE("CreateL2NormalizationInt16")
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100958{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000959 RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::QSymmS16>(DataLayout::NCHW);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100960}
961
Sadik Armagan1625efc2021-06-10 18:24:34 +0100962TEST_CASE("CreateL2NormalizationInt16Nhwc")
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100963{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000964 RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::QSymmS16>(DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000965}
966
Sadik Armagan1625efc2021-06-10 18:24:34 +0100967TEST_CASE("CreateL2NormalizationUint8")
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100968{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000969 RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::QAsymmU8>(DataLayout::NCHW);
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100970}
971
Sadik Armagan1625efc2021-06-10 18:24:34 +0100972TEST_CASE("CreateL2NormalizationUint8Nhwc")
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100973{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000974 RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::QAsymmU8>(DataLayout::NHWC);
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100975}
976
telsoa01c577f2c2018-08-31 09:22:23 +0100977template <typename ReshapeWorkloadType, armnn::DataType DataType>
telsoa014fcda012018-03-09 14:13:49 +0000978static void RefCreateReshapeWorkloadTest()
979{
980 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100981 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +0100982 auto workload = CreateReshapeWorkloadTest<ReshapeWorkloadType, DataType>(factory, graph);
telsoa014fcda012018-03-09 14:13:49 +0000983
telsoa01c577f2c2018-08-31 09:22:23 +0100984 // Checks that outputs and inputs are as we expect them (see definition of CreateReshapeWorkloadTest).
telsoa014fcda012018-03-09 14:13:49 +0000985 CheckInputOutput(
986 std::move(workload),
telsoa01c577f2c2018-08-31 09:22:23 +0100987 TensorInfo({ 4, 1 }, DataType),
988 TensorInfo({ 1, 4 }, DataType));
telsoa014fcda012018-03-09 14:13:49 +0000989}
990
Sadik Armagan1625efc2021-06-10 18:24:34 +0100991TEST_CASE("CreateReshapeWorkloadFloat32")
telsoa014fcda012018-03-09 14:13:49 +0000992{
Nina Drozd2f2778f2019-05-27 10:37:05 +0100993 RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000994}
995
Sadik Armagan1625efc2021-06-10 18:24:34 +0100996TEST_CASE("CreateReshapeWorkloadQuantisedAsymm8")
telsoa014fcda012018-03-09 14:13:49 +0000997{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000998 RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000999}
1000
Sadik Armagan1625efc2021-06-10 18:24:34 +01001001TEST_CASE("CreateReshapeWorkloadQuantisedSymm16")
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001002{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001003 RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::QSymmS16>();
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001004}
1005
Jim Flynne242f2d2019-05-22 14:24:13 +01001006template <typename ConcatWorkloadType, armnn::DataType DataType>
1007static void RefCreateConcatWorkloadTest(const armnn::TensorShape& outputShape,
narpra015cdda352018-11-19 15:30:27 +00001008 unsigned int concatAxis)
1009{
1010 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +01001011 RefWorkloadFactory factory = GetFactory();
Jim Flynne242f2d2019-05-22 14:24:13 +01001012 auto workload = CreateConcatWorkloadTest<ConcatWorkloadType, DataType>(factory, graph, outputShape, concatAxis);
narpra015cdda352018-11-19 15:30:27 +00001013
1014 CheckInputsOutput(std::move(workload),
1015 TensorInfo({ 2, 3, 2, 5 }, DataType),
1016 TensorInfo({ 2, 3, 2, 5 }, DataType),
1017 TensorInfo(outputShape, DataType));
1018}
1019
Sadik Armagan1625efc2021-06-10 18:24:34 +01001020TEST_CASE("CreateConcatDim0Float32Workload")
narpra015cdda352018-11-19 15:30:27 +00001021{
Jim Flynne242f2d2019-05-22 14:24:13 +01001022 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 4, 3, 2, 5 }, 0);
narpra015cdda352018-11-19 15:30:27 +00001023}
1024
Sadik Armagan1625efc2021-06-10 18:24:34 +01001025TEST_CASE("CreateConcatDim0Float16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +01001026{
1027 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float16>({ 4, 3, 2, 5 }, 0);
1028}
1029
Sadik Armagan1625efc2021-06-10 18:24:34 +01001030TEST_CASE("CreateConcatDim0Uint8Workload")
narpra015cdda352018-11-19 15:30:27 +00001031{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001032 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QAsymmU8>({ 4, 3, 2, 5 }, 0);
Jim Flynncbb66aa2019-05-15 13:03:54 +01001033}
1034
Sadik Armagan1625efc2021-06-10 18:24:34 +01001035TEST_CASE("CreateConcatDim0Uint16Workload")
Jim Flynncbb66aa2019-05-15 13:03:54 +01001036{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001037 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QSymmS16>({ 4, 3, 2, 5 }, 0);
narpra015cdda352018-11-19 15:30:27 +00001038}
1039
Sadik Armagan1625efc2021-06-10 18:24:34 +01001040TEST_CASE("CreateConcatDim1Float32Workload")
narpra015cdda352018-11-19 15:30:27 +00001041{
Jim Flynne242f2d2019-05-22 14:24:13 +01001042 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 2, 6, 2, 5 }, 1);
narpra015cdda352018-11-19 15:30:27 +00001043}
1044
Sadik Armagan1625efc2021-06-10 18:24:34 +01001045TEST_CASE("CreateConcatDim1Uint8Workload")
narpra015cdda352018-11-19 15:30:27 +00001046{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001047 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QAsymmU8>({ 2, 6, 2, 5 }, 1);
narpra015cdda352018-11-19 15:30:27 +00001048}
1049
Sadik Armagan1625efc2021-06-10 18:24:34 +01001050TEST_CASE("CreateConcatDim2Float32Workload")
narpra015cdda352018-11-19 15:30:27 +00001051{
Jim Flynne242f2d2019-05-22 14:24:13 +01001052 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 2, 3, 4, 5 }, 2);
narpra015cdda352018-11-19 15:30:27 +00001053}
1054
Sadik Armagan1625efc2021-06-10 18:24:34 +01001055TEST_CASE("CreateConcatDim2Uint8Workload")
narpra015cdda352018-11-19 15:30:27 +00001056{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001057 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QAsymmU8>({ 2, 3, 4, 5 }, 2);
narpra015cdda352018-11-19 15:30:27 +00001058}
1059
Sadik Armagan1625efc2021-06-10 18:24:34 +01001060TEST_CASE("CreateConcatDim3Float32Workload")
narpra015cdda352018-11-19 15:30:27 +00001061{
Jim Flynne242f2d2019-05-22 14:24:13 +01001062 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 2, 3, 2, 10 }, 3);
narpra015cdda352018-11-19 15:30:27 +00001063}
1064
Sadik Armagan1625efc2021-06-10 18:24:34 +01001065TEST_CASE("CreateConcatDim3Uint8Workload")
narpra015cdda352018-11-19 15:30:27 +00001066{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001067 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QAsymmU8>({ 2, 3, 2, 10 }, 3);
narpra015cdda352018-11-19 15:30:27 +00001068}
1069
Nina Drozd58ef2c62019-05-16 12:09:18 +01001070template <typename ConstantWorkloadType, armnn::DataType DataType>
1071static void RefCreateConstantWorkloadTest(const armnn::TensorShape& outputShape)
1072{
1073 armnn::Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +01001074 RefWorkloadFactory factory = GetFactory();
Nina Drozd58ef2c62019-05-16 12:09:18 +01001075 auto workload = CreateConstantWorkloadTest<ConstantWorkloadType, DataType>(factory, graph, outputShape);
1076
1077 // Check output is as expected
1078 auto queueDescriptor = workload->GetData();
Jan Eilersbb446e52020-04-02 13:56:54 +01001079 auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +01001080 CHECK((outputHandle->GetTensorInfo() == TensorInfo(outputShape, DataType)));
Nina Drozd58ef2c62019-05-16 12:09:18 +01001081}
1082
Sadik Armagan1625efc2021-06-10 18:24:34 +01001083TEST_CASE("CreateConstantUint8Workload")
Nina Drozd58ef2c62019-05-16 12:09:18 +01001084{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001085 RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::QAsymmU8>({ 2, 3, 2, 10 });
Nina Drozd58ef2c62019-05-16 12:09:18 +01001086}
1087
Sadik Armagan1625efc2021-06-10 18:24:34 +01001088TEST_CASE("CreateConstantInt16Workload")
Nina Drozd58ef2c62019-05-16 12:09:18 +01001089{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001090 RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::QSymmS16>({ 2, 3, 2, 10 });
Nina Drozd58ef2c62019-05-16 12:09:18 +01001091}
1092
Sadik Armagan1625efc2021-06-10 18:24:34 +01001093TEST_CASE("CreateConstantFloat32Workload")
Nina Drozd58ef2c62019-05-16 12:09:18 +01001094{
1095 RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::Float32>({ 2, 3, 2, 10 });
1096}
1097
Sadik Armagan1625efc2021-06-10 18:24:34 +01001098TEST_CASE("CreateConstantSigned32Workload")
Nina Drozd58ef2c62019-05-16 12:09:18 +01001099{
1100 RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::Signed32>({ 2, 3, 2, 10 });
1101}
1102
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001103static void RefCreatePreluWorkloadTest(const armnn::TensorShape& inputShape,
1104 const armnn::TensorShape& alphaShape,
1105 const armnn::TensorShape& outputShape,
1106 armnn::DataType dataType)
Matteo Martincighab9e5252019-06-13 17:27:46 +01001107{
1108 armnn::Graph graph;
1109 RefWorkloadFactory factory;
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001110 auto workload = CreatePreluWorkloadTest<RefPreluWorkload>(factory,
1111 graph,
1112 inputShape,
1113 alphaShape,
1114 outputShape,
1115 dataType);
Matteo Martincighab9e5252019-06-13 17:27:46 +01001116
1117 // Check output is as expected
1118 auto queueDescriptor = workload->GetData();
Jan Eilersbb446e52020-04-02 13:56:54 +01001119 auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +01001120 CHECK((outputHandle->GetTensorInfo() == TensorInfo(outputShape, dataType)));
Matteo Martincighab9e5252019-06-13 17:27:46 +01001121}
1122
Sadik Armagan1625efc2021-06-10 18:24:34 +01001123TEST_CASE("CreatePreluFloat32Workload")
Matteo Martincighab9e5252019-06-13 17:27:46 +01001124{
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001125 RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::Float32);
Matteo Martincighab9e5252019-06-13 17:27:46 +01001126}
1127
Sadik Armagan1625efc2021-06-10 18:24:34 +01001128TEST_CASE("CreatePreluFloat16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +01001129{
1130 RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::Float16);
1131}
1132
Sadik Armagan1625efc2021-06-10 18:24:34 +01001133TEST_CASE("CreatePreluUint8Workload")
Matteo Martincighab9e5252019-06-13 17:27:46 +01001134{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001135 RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::QAsymmU8);
Matteo Martincighab9e5252019-06-13 17:27:46 +01001136}
1137
Sadik Armagan1625efc2021-06-10 18:24:34 +01001138TEST_CASE("CreatePreluInt16Workload")
Matteo Martincighab9e5252019-06-13 17:27:46 +01001139{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001140 RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::QSymmS16);
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001141}
1142
Sadik Armagan1625efc2021-06-10 18:24:34 +01001143TEST_CASE("CreatePreluFloat32NoBroadcastWorkload")
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001144{
Sadik Armagan1625efc2021-06-10 18:24:34 +01001145 CHECK_THROWS_AS(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001146 armnn::DataType::Float32),
1147 armnn::InvalidArgumentException);
1148}
1149
Sadik Armagan1625efc2021-06-10 18:24:34 +01001150TEST_CASE("CreatePreluFloat16NoBroadcastWorkload")
Matthew Jackson9bff1442019-09-12 09:08:23 +01001151{
Sadik Armagan1625efc2021-06-10 18:24:34 +01001152 CHECK_THROWS_AS(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
Matthew Jackson9bff1442019-09-12 09:08:23 +01001153 armnn::DataType::Float16),
1154 armnn::InvalidArgumentException);
1155}
1156
Sadik Armagan1625efc2021-06-10 18:24:34 +01001157TEST_CASE("CreatePreluUint8NoBroadcastWorkload")
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001158{
Sadik Armagan1625efc2021-06-10 18:24:34 +01001159 CHECK_THROWS_AS(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
Derek Lambertif90c56d2020-01-10 17:14:08 +00001160 armnn::DataType::QAsymmU8),
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001161 armnn::InvalidArgumentException);
1162}
1163
Sadik Armagan1625efc2021-06-10 18:24:34 +01001164TEST_CASE("CreatePreluInt16NoBroadcastWorkload")
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001165{
Sadik Armagan1625efc2021-06-10 18:24:34 +01001166 CHECK_THROWS_AS(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
Derek Lambertif90c56d2020-01-10 17:14:08 +00001167 armnn::DataType::QSymmS16),
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001168 armnn::InvalidArgumentException);
Matteo Martincighab9e5252019-06-13 17:27:46 +01001169}
1170
James Conroy60597842019-07-02 10:57:56 +01001171template <typename SpaceToDepthWorkloadType, armnn::DataType DataType>
1172static void RefCreateSpaceToDepthWorkloadTest()
1173{
1174 Graph graph;
1175 RefWorkloadFactory factory;
1176
1177 auto workload = CreateSpaceToDepthWorkloadTest<SpaceToDepthWorkloadType, DataType>(factory, graph);
1178
1179 CheckInputOutput(std::move(workload),
1180 TensorInfo({ 1, 2, 2, 1 }, DataType),
1181 TensorInfo({ 1, 1, 1, 4 }, DataType));
1182}
1183
Sadik Armagan1625efc2021-06-10 18:24:34 +01001184TEST_CASE("CreateSpaceToDepthWorkloadFloat32")
James Conroy60597842019-07-02 10:57:56 +01001185{
1186 RefCreateSpaceToDepthWorkloadTest<RefSpaceToDepthWorkload, armnn::DataType::Float32>();
1187}
1188
Sadik Armagan1625efc2021-06-10 18:24:34 +01001189TEST_CASE("CreateSpaceToDepthWorkloadFloat16")
Matthew Jackson9bff1442019-09-12 09:08:23 +01001190{
1191 RefCreateSpaceToDepthWorkloadTest<RefSpaceToDepthWorkload, armnn::DataType::Float16>();
1192}
1193
Sadik Armagan1625efc2021-06-10 18:24:34 +01001194TEST_CASE("CreateSpaceToDepthWorkloadQASymm8")
James Conroy60597842019-07-02 10:57:56 +01001195{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001196 RefCreateSpaceToDepthWorkloadTest<RefSpaceToDepthWorkload, armnn::DataType::QAsymmU8>();
James Conroy60597842019-07-02 10:57:56 +01001197}
1198
Sadik Armagan1625efc2021-06-10 18:24:34 +01001199TEST_CASE("CreateSpaceToDepthWorkloadQSymm16")
James Conroy60597842019-07-02 10:57:56 +01001200{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001201 RefCreateSpaceToDepthWorkloadTest<RefSpaceToDepthWorkload, armnn::DataType::QSymmS16>();
James Conroy60597842019-07-02 10:57:56 +01001202}
1203
Matthew Jacksond5166102019-07-31 14:06:28 +01001204template <armnn::DataType DataType>
Matthew Jackson81e601c2019-07-11 12:07:09 +01001205static void RefCreateStackWorkloadTest(const armnn::TensorShape& inputShape,
1206 const armnn::TensorShape& outputShape,
1207 unsigned int axis,
Matthew Jacksond5166102019-07-31 14:06:28 +01001208 unsigned int numInputs)
Matthew Jackson81e601c2019-07-11 12:07:09 +01001209{
1210 armnn::Graph graph;
1211 RefWorkloadFactory factory;
Matthew Jacksond5166102019-07-31 14:06:28 +01001212 auto workload = CreateStackWorkloadTest<RefStackWorkload, DataType>(factory,
1213 graph,
1214 inputShape,
1215 outputShape,
1216 axis,
1217 numInputs);
Matthew Jackson81e601c2019-07-11 12:07:09 +01001218
Matthew Jacksond5166102019-07-31 14:06:28 +01001219 // Check inputs and output are as expected
1220 StackQueueDescriptor queueDescriptor = workload->GetData();
1221 for (unsigned int i = 0; i < numInputs; ++i)
1222 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001223 auto inputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[i]);
Sadik Armagan1625efc2021-06-10 18:24:34 +01001224 CHECK((inputHandle->GetTensorInfo() == TensorInfo(inputShape, DataType)));
Matthew Jacksond5166102019-07-31 14:06:28 +01001225 }
Jan Eilersbb446e52020-04-02 13:56:54 +01001226 auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +01001227 CHECK((outputHandle->GetTensorInfo() == TensorInfo(outputShape, DataType)));
Matthew Jackson81e601c2019-07-11 12:07:09 +01001228}
1229
Sadik Armagan1625efc2021-06-10 18:24:34 +01001230TEST_CASE("CreateStackFloat32Workload")
Matthew Jackson81e601c2019-07-11 12:07:09 +01001231{
Matthew Jacksond5166102019-07-31 14:06:28 +01001232 RefCreateStackWorkloadTest<armnn::DataType::Float32>({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2);
Matthew Jackson81e601c2019-07-11 12:07:09 +01001233}
1234
Sadik Armagan1625efc2021-06-10 18:24:34 +01001235TEST_CASE("CreateStackUint8Workload")
Matthew Jackson81e601c2019-07-11 12:07:09 +01001236{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001237 RefCreateStackWorkloadTest<armnn::DataType::QAsymmU8>({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2);
Matthew Jackson81e601c2019-07-11 12:07:09 +01001238}
1239
Sadik Armagan1625efc2021-06-10 18:24:34 +01001240TEST_CASE("CreateStackUint16Workload")
Matthew Jackson81e601c2019-07-11 12:07:09 +01001241{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001242 RefCreateStackWorkloadTest<armnn::DataType::QSymmS16>({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2);
Matthew Jackson81e601c2019-07-11 12:07:09 +01001243}
1244
James Conroy4f1f8992020-04-29 20:01:10 +01001245template <typename QLstmWorkloadType>
1246static void RefCreateQLstmWorkloadTest()
1247{
1248 Graph graph;
1249 RefWorkloadFactory factory;
1250
1251 auto workload = CreateQLstmWorkloadTest<QLstmWorkloadType>(factory, graph);
1252
1253 armnn::TensorInfo inputInfo({2 , 4}, armnn::DataType::QAsymmS8, 0.0078125f, 0);
1254
1255 armnn::TensorInfo cellStateInfo({2 , 4}, armnn::DataType::QSymmS16, 3.05176e-05f, 0);
1256
1257 armnn::TensorInfo outputInfo({2 , 4}, armnn::DataType::QAsymmS8, 0.007f, 0);
1258
1259 QLstmQueueDescriptor queueDescriptor = workload->GetData();
Jan Eilersaaf9a8f2020-07-01 16:35:35 +01001260 auto inputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
1261 auto cellStateOutHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[1]);
1262 auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[2]);
James Conroy4f1f8992020-04-29 20:01:10 +01001263
Sadik Armagan1625efc2021-06-10 18:24:34 +01001264 CHECK((inputHandle->GetTensorInfo() == inputInfo));
1265 CHECK((cellStateOutHandle->GetTensorInfo() == cellStateInfo));
1266 CHECK((outputHandle->GetTensorInfo() == outputInfo));
James Conroy4f1f8992020-04-29 20:01:10 +01001267}
1268
Sadik Armagan1625efc2021-06-10 18:24:34 +01001269TEST_CASE("CreateQLstmWorkload")
James Conroy4f1f8992020-04-29 20:01:10 +01001270{
1271 RefCreateQLstmWorkloadTest<RefQLstmWorkload>();
1272}
1273
Teresa Charlin788e2a62022-01-17 21:19:52 +00001274template <armnn::DataType DataType>
1275static void RefCreateActivationWorkloadReplaceFunctionsTest()
1276{
1277 Graph graph;
1278 RefWorkloadFactory factory = GetFactory();
1279 // input and output are created as armnn::TensorInfo tensorInfo({1, 1}, DataType)
1280 auto workloadPtr = CreateActivationWorkloadTest<RefActivationWorkload, DataType>(factory, graph);
1281
1282 // new input and output tensor handlers are created and then replace in the workload
1283 shared_ptr<RefMemoryManager> memoryManager = make_shared<RefMemoryManager>();
1284 const RefTensorHandleFactory tensorHandleFactory(memoryManager);
Rob Hughes5bcc0722022-01-21 10:56:14 +00001285 TensorInfo inputInfo({2 , 2}, armnn::DataType::Float16);
1286 TensorInfo outputInfo({2 , 2}, armnn::DataType::Float16);
Teresa Charlin788e2a62022-01-17 21:19:52 +00001287 unique_ptr<ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
1288 unique_ptr<ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo);
1289 unsigned int slot = 0;
1290 workloadPtr->ReplaceInputTensorHandle(inputHandle.get(), slot);
1291 workloadPtr->ReplaceOutputTensorHandle(outputHandle.get(), slot);
1292
1293 // Check if the tensor handlers inside the workload are the same as ones we replace with
1294 auto queueDescriptor = workloadPtr->GetData();
1295 auto inputHandleTest = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
1296 auto outputHandleTest = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
1297 CHECK((inputHandleTest->GetTensorInfo() == inputInfo));
1298 CHECK((outputHandleTest->GetTensorInfo() == outputInfo));
1299 CHECK(inputHandle.get() == inputHandleTest);
1300 CHECK(outputHandle.get() == outputHandleTest);
1301 inputHandle->Allocate();
1302 CHECK(inputHandle->Map() == inputHandleTest->Map());
1303 outputHandle->Allocate();
1304 CHECK(outputHandle->Map() == outputHandleTest->Map());
1305}
1306
1307TEST_CASE("ReplaceFunctionsfromFloat32toFloat16ActivationWorkload")
1308{
1309 RefCreateActivationWorkloadReplaceFunctionsTest<armnn::DataType::Float32>();
1310}
1311
1312TEST_CASE("ReplaceFunctionsfromUint8toFloat16ActivationWorkload")
1313{
1314 RefCreateActivationWorkloadReplaceFunctionsTest<armnn::DataType::QAsymmU8>();
1315}
1316
Mike Kelly4cc341c2023-07-07 15:43:06 +01001317bool TestRefTensorHandleInfo(armnn::RefTensorHandle* handle, const armnn::TensorInfo& expectedInfo)
1318{
1319 const TensorInfo handleInfo = handle->GetTensorInfo();
1320 const TensorInfo expectedAclInfo = expectedInfo;
1321
1322 if (handleInfo.GetDataType() != expectedAclInfo.GetDataType())
1323 {
1324 return false;
1325 }
1326
1327 if (handleInfo.GetNumDimensions() != expectedAclInfo.GetNumDimensions())
1328 {
1329 return false;
1330 }
1331
1332 for (unsigned int d = 0; d < expectedAclInfo.GetNumDimensions(); ++d)
1333 {
1334 if (handleInfo.GetShape()[d] != expectedAclInfo.GetShape()[d])
1335 {
1336 return false;
1337 }
1338 }
1339
1340 return true;
1341}
1342
1343TEST_CASE("RefCreateSplitterWorkload")
1344{
1345 Graph graph;
1346 RefWorkloadFactory factory = GetFactory();
1347
1348 auto workload = CreateSplitterWorkloadTest<RefSplitterWorkload, DataType::Float32>(factory, graph);
1349
1350 // Checks that outputs are as we expect them (see definition of CreateSplitterWorkloadTest).
1351 SplitterQueueDescriptor queueDescriptor = workload->GetData();
1352 auto inputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
1353 CHECK(TestRefTensorHandleInfo(inputHandle, TensorInfo({5, 7, 7}, DataType::Float32)));
1354
1355 auto outputHandle0 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
1356 CHECK(TestRefTensorHandleInfo(outputHandle0, TensorInfo({1, 7, 7}, DataType::Float32)));
1357
1358 auto outputHandle1 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[1]);
1359 CHECK(TestRefTensorHandleInfo(outputHandle1, TensorInfo({2, 7, 7}, DataType::Float32)));
1360
1361 auto outputHandle2 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[2]);
1362 CHECK(TestRefTensorHandleInfo(outputHandle2, TensorInfo({2, 7, 7}, DataType::Float32)));
1363}
1364
Sadik Armagan1625efc2021-06-10 18:24:34 +01001365}