blob: 4293ef54f3f9bf46fa77c2e8428bf0405391fba0 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
arovir0143095f32018-10-09 18:04:24 +01005
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00006#include <test/CreateWorkload.hpp>
arovir0143095f32018-10-09 18:04:24 +01007
Jan Eilersbb446e52020-04-02 13:56:54 +01008#include <armnn/utility/PolymorphicDowncast.hpp>
Matthew Bentham4cefc412019-06-18 16:14:34 +01009#include <reference/RefTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000010#include <reference/RefWorkloadFactory.hpp>
11#include <reference/workloads/RefWorkloads.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
Sadik Armagan1625efc2021-06-10 18:24:34 +010013#include <doctest/doctest.h>
14
telsoa014fcda012018-03-09 14:13:49 +000015namespace
16{
17
18template<typename Workload>
19void CheckInputOutput(std::unique_ptr<Workload> workload, const TensorInfo& inputInfo, const TensorInfo& outputInfo)
20{
21 auto queueDescriptor = workload->GetData();
Jan Eilersbb446e52020-04-02 13:56:54 +010022 auto inputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
23 auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +010024 CHECK((inputHandle->GetTensorInfo() == inputInfo));
25 CHECK((outputHandle->GetTensorInfo() == outputInfo));
telsoa014fcda012018-03-09 14:13:49 +000026}
27
28template <typename Workload>
29void CheckInputsOutput(std::unique_ptr<Workload> workload,
30 const TensorInfo& inputInfo0,
31 const TensorInfo& inputInfo1,
32 const TensorInfo& outputInfo)
33{
34 auto queueDescriptor = workload->GetData();
Jan Eilersbb446e52020-04-02 13:56:54 +010035 auto inputHandle0 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
36 auto inputHandle1 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[1]);
37 auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +010038 CHECK((inputHandle0->GetTensorInfo() == inputInfo0));
39 CHECK((inputHandle1->GetTensorInfo() == inputInfo1));
40 CHECK((outputHandle->GetTensorInfo() == outputInfo));
telsoa014fcda012018-03-09 14:13:49 +000041}
Matthew Bentham7c1603a2019-06-21 17:22:23 +010042
43armnn::RefWorkloadFactory GetFactory()
44{
45 std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
46 return RefWorkloadFactory(memoryManager);
47}
48
49
telsoa014fcda012018-03-09 14:13:49 +000050}
51
Sadik Armagan1625efc2021-06-10 18:24:34 +010052TEST_SUITE("CreateWorkloadRef")
53{
telsoa01c577f2c2018-08-31 09:22:23 +010054template <typename ActivationWorkloadType, armnn::DataType DataType>
telsoa014fcda012018-03-09 14:13:49 +000055static void RefCreateActivationWorkloadTest()
56{
57 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +010058 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +010059 auto workload = CreateActivationWorkloadTest<ActivationWorkloadType, DataType>(factory, graph);
telsoa014fcda012018-03-09 14:13:49 +000060
telsoa01c577f2c2018-08-31 09:22:23 +010061 // Checks that outputs are as we expect them (see definition of CreateActivationWorkloadTest).
telsoa014fcda012018-03-09 14:13:49 +000062 CheckInputOutput(std::move(workload),
telsoa01c577f2c2018-08-31 09:22:23 +010063 TensorInfo({ 1, 1 }, DataType),
64 TensorInfo({ 1, 1 }, DataType));
telsoa014fcda012018-03-09 14:13:49 +000065}
66
Sadik Armagan1625efc2021-06-10 18:24:34 +010067TEST_CASE("CreateActivationFloat32Workload")
telsoa014fcda012018-03-09 14:13:49 +000068{
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +010069 RefCreateActivationWorkloadTest<RefActivationWorkload, armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +000070}
71
Sadik Armagan1625efc2021-06-10 18:24:34 +010072TEST_CASE("CreateActivationUint8Workload")
telsoa014fcda012018-03-09 14:13:49 +000073{
Derek Lambertif90c56d2020-01-10 17:14:08 +000074 RefCreateActivationWorkloadTest<RefActivationWorkload, armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +000075}
76
David Beckbc392452018-09-10 14:47:28 +010077template <typename WorkloadType,
78 typename DescriptorType,
79 typename LayerType,
80 armnn::DataType DataType>
Éanna Ó Catháind57415d2018-11-28 16:24:38 +000081static void RefCreateElementwiseWorkloadTest()
telsoa014fcda012018-03-09 14:13:49 +000082{
83 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +010084 RefWorkloadFactory factory = GetFactory();
Éanna Ó Catháind57415d2018-11-28 16:24:38 +000085 auto workload = CreateElementwiseWorkloadTest<WorkloadType, DescriptorType, LayerType, DataType>(
86 factory, graph);
telsoa014fcda012018-03-09 14:13:49 +000087
telsoa014fcda012018-03-09 14:13:49 +000088 CheckInputsOutput(std::move(workload),
telsoa01c577f2c2018-08-31 09:22:23 +010089 TensorInfo({ 2, 3 }, DataType),
90 TensorInfo({ 2, 3 }, DataType),
91 TensorInfo({ 2, 3 }, DataType));
telsoa014fcda012018-03-09 14:13:49 +000092}
93
Sadik Armagan1625efc2021-06-10 18:24:34 +010094TEST_CASE("CreateSubtractionWorkloadWithBlobTest")
Keith Davisdf04d232020-10-23 17:20:05 +010095{
96 Graph graph;
97 RefWorkloadFactory factory = GetFactory();
98 armnn::DataType DataType = armnn::DataType::Float32;
99
100 auto workload = CreateSubtractionWithBlobWorkloadTest<RefSubtractionWorkload<>,
101 SubtractionQueueDescriptor,
102 armnn::DataType::Float32>
103 (factory, graph);
104
105 CheckInputsOutput(std::move(workload),
106 TensorInfo({ 2, 3 }, DataType),
107 TensorInfo({ 2, 3 }, DataType),
108 TensorInfo({ 2, 3 }, DataType));
109}
110
Sadik Armagan1625efc2021-06-10 18:24:34 +0100111TEST_CASE("CreateAdditionWorkloadWithBlobTest")
Keith Davisdf04d232020-10-23 17:20:05 +0100112{
113 Graph graph;
114 RefWorkloadFactory factory = GetFactory();
115 armnn::DataType DataType = armnn::DataType::Float32;
116
117 auto workload = CreateAdditionWithBlobWorkloadTest<RefAdditionWorkload<>,
118 AdditionQueueDescriptor,
119 armnn::DataType::Float32>(factory, graph);
120
121 CheckInputsOutput(std::move(workload),
122 TensorInfo({ 2, 3 }, DataType),
123 TensorInfo({ 2, 3 }, DataType),
124 TensorInfo({ 2, 3 }, DataType));
125}
126
Sadik Armagan1625efc2021-06-10 18:24:34 +0100127TEST_CASE("CreateMultiplicationWorkloadWithBlobTest")
Keith Davisdf04d232020-10-23 17:20:05 +0100128{
129 Graph graph;
130 RefWorkloadFactory factory = GetFactory();
131 armnn::DataType DataType = armnn::DataType::Float32;
132
133 auto workload = CreateMultiplicationWithBlobWorkloadTest<RefMultiplicationWorkload<>,
134 MultiplicationQueueDescriptor,
135 armnn::DataType::Float32>(factory, graph);
136
137 CheckInputsOutput(std::move(workload),
138 TensorInfo({2, 3}, DataType),
139 TensorInfo({2, 3}, DataType),
140 TensorInfo({2, 3}, DataType));
141}
142
Sadik Armagan1625efc2021-06-10 18:24:34 +0100143TEST_CASE("CreateAdditionFloatWorkload")
telsoa014fcda012018-03-09 14:13:49 +0000144{
Finn Williamscbd2c232020-06-22 15:58:32 +0100145 RefCreateElementwiseWorkloadTest<RefAdditionWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000146 AdditionQueueDescriptor,
147 AdditionLayer,
148 armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000149}
150
Sadik Armagan1625efc2021-06-10 18:24:34 +0100151TEST_CASE("CreateAdditionUint8Workload")
telsoa014fcda012018-03-09 14:13:49 +0000152{
Finn Williamscbd2c232020-06-22 15:58:32 +0100153 RefCreateElementwiseWorkloadTest<RefAdditionWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000154 AdditionQueueDescriptor,
155 AdditionLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000156 armnn::DataType::QAsymmU8>();
David Beckbc392452018-09-10 14:47:28 +0100157}
158
Sadik Armagan1625efc2021-06-10 18:24:34 +0100159TEST_CASE("CreateAdditionInt16Workload")
Sadik Armagan2999a022019-04-09 14:20:12 +0100160{
Finn Williamscbd2c232020-06-22 15:58:32 +0100161 RefCreateElementwiseWorkloadTest<RefAdditionWorkload<>,
Sadik Armagan2999a022019-04-09 14:20:12 +0100162 AdditionQueueDescriptor,
163 AdditionLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000164 armnn::DataType::QSymmS16>();
Sadik Armagan2999a022019-04-09 14:20:12 +0100165}
166
Sadik Armagan1625efc2021-06-10 18:24:34 +0100167TEST_CASE("CreateAdditionInt32Workload")
Finn Williamscbd2c232020-06-22 15:58:32 +0100168{
169 RefCreateElementwiseWorkloadTest<RefAdditionWorkload<int32_t>,
170 AdditionQueueDescriptor,
171 AdditionLayer,
172 armnn::DataType::Signed32>();
173}
174
Sadik Armagan1625efc2021-06-10 18:24:34 +0100175TEST_CASE("CreateSubtractionFloat32Workload")
David Beckbc392452018-09-10 14:47:28 +0100176{
Finn Williamscbd2c232020-06-22 15:58:32 +0100177 RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000178 SubtractionQueueDescriptor,
179 SubtractionLayer,
180 armnn::DataType::Float32>();
David Beckbc392452018-09-10 14:47:28 +0100181}
182
Sadik Armagan1625efc2021-06-10 18:24:34 +0100183TEST_CASE("CreateSubtractionFloat16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100184{
Finn Williamscbd2c232020-06-22 15:58:32 +0100185 RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100186 SubtractionQueueDescriptor,
187 SubtractionLayer,
188 armnn::DataType::Float16>();
189}
190
Sadik Armagan1625efc2021-06-10 18:24:34 +0100191TEST_CASE("CreateSubtractionUint8Workload")
David Beckbc392452018-09-10 14:47:28 +0100192{
Finn Williamscbd2c232020-06-22 15:58:32 +0100193 RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000194 SubtractionQueueDescriptor,
195 SubtractionLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000196 armnn::DataType::QAsymmU8>();
David Beckbc392452018-09-10 14:47:28 +0100197}
198
Sadik Armagan1625efc2021-06-10 18:24:34 +0100199TEST_CASE("CreateSubtractionInt16Workload")
Sadik Armagan2999a022019-04-09 14:20:12 +0100200{
Finn Williamscbd2c232020-06-22 15:58:32 +0100201 RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
Sadik Armagan2999a022019-04-09 14:20:12 +0100202 SubtractionQueueDescriptor,
203 SubtractionLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000204 armnn::DataType::QSymmS16>();
Sadik Armagan2999a022019-04-09 14:20:12 +0100205}
206
Sadik Armagan1625efc2021-06-10 18:24:34 +0100207TEST_CASE("CreateSubtractionInt32Workload")
Finn Williamscbd2c232020-06-22 15:58:32 +0100208{
209 RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<int32_t>,
210 SubtractionQueueDescriptor,
211 SubtractionLayer,
212 armnn::DataType::Signed32>();
213}
214
Sadik Armagan1625efc2021-06-10 18:24:34 +0100215TEST_CASE("CreateMultiplicationFloatWorkload")
David Beckbc392452018-09-10 14:47:28 +0100216{
Finn Williamscbd2c232020-06-22 15:58:32 +0100217 RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000218 MultiplicationQueueDescriptor,
219 MultiplicationLayer,
220 armnn::DataType::Float32>();
David Beckbc392452018-09-10 14:47:28 +0100221}
222
Sadik Armagan1625efc2021-06-10 18:24:34 +0100223TEST_CASE("CreateMultiplicationUint8Workload")
David Beckbc392452018-09-10 14:47:28 +0100224{
Finn Williamscbd2c232020-06-22 15:58:32 +0100225 RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000226 MultiplicationQueueDescriptor,
227 MultiplicationLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000228 armnn::DataType::QAsymmU8>();
David Beckbc392452018-09-10 14:47:28 +0100229}
230
Sadik Armagan1625efc2021-06-10 18:24:34 +0100231TEST_CASE("CreateMultiplicationInt16Workload")
Sadik Armagan2999a022019-04-09 14:20:12 +0100232{
Finn Williamscbd2c232020-06-22 15:58:32 +0100233 RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<>,
Sadik Armagan2999a022019-04-09 14:20:12 +0100234 MultiplicationQueueDescriptor,
235 MultiplicationLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000236 armnn::DataType::QSymmS16>();
Sadik Armagan2999a022019-04-09 14:20:12 +0100237}
238
Sadik Armagan1625efc2021-06-10 18:24:34 +0100239TEST_CASE("CreateMultiplicationInt32Workload")
Finn Williamscbd2c232020-06-22 15:58:32 +0100240{
241 RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<int32_t>,
242 MultiplicationQueueDescriptor,
243 MultiplicationLayer,
244 armnn::DataType::Signed32>();
245}
246
Sadik Armagan1625efc2021-06-10 18:24:34 +0100247TEST_CASE("CreateDivisionFloat32Workload")
David Beckbc392452018-09-10 14:47:28 +0100248{
Finn Williamscbd2c232020-06-22 15:58:32 +0100249 RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000250 DivisionQueueDescriptor,
251 DivisionLayer,
252 armnn::DataType::Float32>();
David Beckbc392452018-09-10 14:47:28 +0100253}
254
Sadik Armagan1625efc2021-06-10 18:24:34 +0100255TEST_CASE("CreateDivisionFloat16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100256{
Finn Williamscbd2c232020-06-22 15:58:32 +0100257 RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100258 DivisionQueueDescriptor,
259 DivisionLayer,
260 armnn::DataType::Float16>();
261}
262
Sadik Armagan1625efc2021-06-10 18:24:34 +0100263TEST_CASE("CreateDivisionUint8Workload")
David Beckbc392452018-09-10 14:47:28 +0100264{
Finn Williamscbd2c232020-06-22 15:58:32 +0100265 RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000266 DivisionQueueDescriptor,
267 DivisionLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000268 armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000269}
270
Sadik Armagan1625efc2021-06-10 18:24:34 +0100271TEST_CASE("CreateDivisionInt16Workload")
Sadik Armagan2999a022019-04-09 14:20:12 +0100272{
Finn Williamscbd2c232020-06-22 15:58:32 +0100273 RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
Sadik Armagan2999a022019-04-09 14:20:12 +0100274 DivisionQueueDescriptor,
275 DivisionLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000276 armnn::DataType::QSymmS16>();
Sadik Armagan2999a022019-04-09 14:20:12 +0100277}
278
Sadik Armagan1625efc2021-06-10 18:24:34 +0100279TEST_CASE("CreateDivisionInt32Workload")
Finn Williamscbd2c232020-06-22 15:58:32 +0100280{
281 RefCreateElementwiseWorkloadTest<RefDivisionWorkload<int32_t>,
282 DivisionQueueDescriptor,
283 DivisionLayer,
284 armnn::DataType::Signed32>();
285}
286
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100287template <typename BatchNormalizationWorkloadType, armnn::DataType DataType>
288static void RefCreateBatchNormalizationWorkloadTest(DataLayout dataLayout)
telsoa014fcda012018-03-09 14:13:49 +0000289{
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100290 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100291 RefWorkloadFactory factory = GetFactory();
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100292 auto workload = CreateBatchNormalizationWorkloadTest<BatchNormalizationWorkloadType, DataType>(factory,
293 graph,
294 dataLayout);
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100295
296 TensorShape inputShape;
297 TensorShape outputShape;
298
299 switch (dataLayout)
300 {
301 case DataLayout::NHWC:
Nikhil Rajd1340932018-10-18 14:27:50 +0100302 inputShape = { 2, 4, 4, 3 };
303 outputShape = { 2, 4, 4, 3 };
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100304 break;
305 case DataLayout::NCHW:
306 default:
Nikhil Rajd1340932018-10-18 14:27:50 +0100307 inputShape = { 2, 3, 4, 4 };
308 outputShape = { 2, 3, 4, 4 };
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100309 break;
310 }
telsoa014fcda012018-03-09 14:13:49 +0000311
telsoa01c577f2c2018-08-31 09:22:23 +0100312 // Checks that outputs and inputs are as we expect them (see definition of CreateBatchNormalizationWorkloadTest).
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100313 CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
314}
315
Sadik Armagan1625efc2021-06-10 18:24:34 +0100316TEST_CASE("CreateBatchNormalizationWithBlobFloat32Workload")
Keith Davisdf04d232020-10-23 17:20:05 +0100317{
318 Graph graph;
319 RefWorkloadFactory factory = GetFactory();
320 auto dataType = armnn::DataType::Float32;
321 auto workload = CreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload,
322 armnn::DataType::Float32>(factory, graph, DataLayout::NHWC);
323
324 TensorShape inputShape;
325 TensorShape outputShape;
326
327 inputShape = { 2, 4, 4, 3 };
328 outputShape = { 2, 4, 4, 3 };
329
330 // Checks that outputs and inputs are as we expect them (see definition of CreateBatchNormalizationWorkloadTest).
331 CheckInputOutput(std::move(workload), TensorInfo(inputShape, dataType), TensorInfo(outputShape, dataType));
332}
333
Sadik Armagan1625efc2021-06-10 18:24:34 +0100334TEST_CASE("CreateBatchNormalizationFloat32Workload")
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100335{
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100336 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload,armnn::DataType::Float32>
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100337 (DataLayout::NCHW);
338}
339
Sadik Armagan1625efc2021-06-10 18:24:34 +0100340TEST_CASE("CreateBatchNormalizationFloat32WorkloadNhwc")
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100341{
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100342 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::Float32>
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100343 (DataLayout::NHWC);
344}
345
Sadik Armagan1625efc2021-06-10 18:24:34 +0100346TEST_CASE("CreateBatchNormalizationFloat16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100347{
348 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload,armnn::DataType::Float16>
349 (DataLayout::NCHW);
350}
351
Sadik Armagan1625efc2021-06-10 18:24:34 +0100352TEST_CASE("CreateBatchNormalizationFloat16WorkloadNhwc")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100353{
354 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::Float16>
355 (DataLayout::NHWC);
356}
357
Sadik Armagan1625efc2021-06-10 18:24:34 +0100358TEST_CASE("CreateBatchNormalizationUint8Workload")
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100359{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000360 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QAsymmU8>
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100361 (DataLayout::NCHW);
362}
363
Sadik Armagan1625efc2021-06-10 18:24:34 +0100364TEST_CASE("CreateBatchNormalizationUint8WorkloadNhwc")
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100365{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000366 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QAsymmU8>
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100367 (DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000368}
369
Sadik Armagan1625efc2021-06-10 18:24:34 +0100370TEST_CASE("CreateBatchNormalizationInt16Workload")
Matteo Martincighf5507132019-06-04 10:59:47 +0100371{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000372 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QSymmS16>
Matteo Martincighf5507132019-06-04 10:59:47 +0100373 (DataLayout::NCHW);
374}
375
Sadik Armagan1625efc2021-06-10 18:24:34 +0100376TEST_CASE("CreateBatchNormalizationInt16WorkloadNhwc")
Matteo Martincighf5507132019-06-04 10:59:47 +0100377{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000378 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QSymmS16>
Matteo Martincighf5507132019-06-04 10:59:47 +0100379 (DataLayout::NHWC);
380}
381
Sadik Armagan1625efc2021-06-10 18:24:34 +0100382TEST_CASE("CreateConvertFp16ToFp32Float32Workload")
telsoa01c577f2c2018-08-31 09:22:23 +0100383{
384 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100385 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +0100386 auto workload = CreateConvertFp16ToFp32WorkloadTest<RefConvertFp16ToFp32Workload>(factory, graph);
387
388 // Checks that outputs and inputs are as we expect them
389 CheckInputOutput(
390 std::move(workload), TensorInfo({1, 3, 2, 3}, DataType::Float16), TensorInfo({1, 3, 2, 3}, DataType::Float32));
391}
392
Sadik Armagan1625efc2021-06-10 18:24:34 +0100393TEST_CASE("CreateConvertFp32ToFp16Float16Workload")
telsoa01c577f2c2018-08-31 09:22:23 +0100394{
395 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100396 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +0100397 auto workload = CreateConvertFp32ToFp16WorkloadTest<RefConvertFp32ToFp16Workload>(factory, graph);
398
399 // Checks that outputs and inputs are as we expect them
400 CheckInputOutput(
401 std::move(workload), TensorInfo({1, 3, 2, 3}, DataType::Float32), TensorInfo({1, 3, 2, 3}, DataType::Float16));
402}
403
Nikhil Raje4dfd6e2018-10-18 10:11:04 +0100404static void RefCreateConvolution2dWorkloadTest(DataLayout dataLayout = DataLayout::NCHW)
telsoa014fcda012018-03-09 14:13:49 +0000405{
Nikhil Raje4dfd6e2018-10-18 10:11:04 +0100406 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100407 RefWorkloadFactory factory = GetFactory();
Mike Kelly9b398322019-05-22 17:21:49 +0100408 auto workload = CreateConvolution2dWorkloadTest<RefConvolution2dWorkload, DataType::Float32>
Nikhil Raje4dfd6e2018-10-18 10:11:04 +0100409 (factory, graph, dataLayout);
410
Mike Kellydb482882019-06-14 12:35:24 +0100411 TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({2, 3, 8, 16})
412 : std::initializer_list<unsigned int>({2, 8, 16, 3});
413 TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({2, 2, 2, 10})
414 : std::initializer_list<unsigned int>({2, 2, 10, 2});
telsoa014fcda012018-03-09 14:13:49 +0000415
telsoa01c577f2c2018-08-31 09:22:23 +0100416 // Checks that outputs and inputs are as we expect them (see definition of CreateConvolution2dWorkloadTest).
telsoa014fcda012018-03-09 14:13:49 +0000417 CheckInputOutput(std::move(workload),
Nikhil Raje4dfd6e2018-10-18 10:11:04 +0100418 TensorInfo(inputShape, DataType::Float32),
419 TensorInfo(outputShape, DataType::Float32));
420}
421
Sadik Armagan1625efc2021-06-10 18:24:34 +0100422TEST_CASE("CreateConvolution2dFloatNchwWorkload")
Nikhil Raje4dfd6e2018-10-18 10:11:04 +0100423{
424 RefCreateConvolution2dWorkloadTest(DataLayout::NCHW);
425}
426
Sadik Armagan1625efc2021-06-10 18:24:34 +0100427TEST_CASE("CreateConvolution2dFloatNhwcWorkload")
Nikhil Raje4dfd6e2018-10-18 10:11:04 +0100428{
429 RefCreateConvolution2dWorkloadTest(DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000430}
431
Sadik Armagan1625efc2021-06-10 18:24:34 +0100432TEST_CASE("CreateConvolution2dWithBlobWorkload")
Keith Davisdf04d232020-10-23 17:20:05 +0100433{
434 DataLayout dataLayout = DataLayout::NHWC;
435 Graph graph;
436 RefWorkloadFactory factory = GetFactory();
437 auto workload = CreateConvolution2dFusedActivationWithBlobWorkloadTest<RefConvolution2dWorkload, DataType::Float32>
438 (factory, graph, dataLayout);
439
440 TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({2, 3, 8, 16})
441 : std::initializer_list<unsigned int>({2, 8, 16, 3});
442 TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({2, 2, 2, 10})
443 : std::initializer_list<unsigned int>({2, 2, 10, 2});
444
445 // Checks that outputs and inputs are as we expect them (see definition of CreateConvolution2dWorkloadTest).
446 CheckInputOutput(std::move(workload),
447 TensorInfo(inputShape, DataType::Float32),
448 TensorInfo(outputShape, DataType::Float32));
449}
450
Ruomei Yan495852f2019-05-23 11:37:33 +0100451static void RefCreateDepthwiseConvolutionWorkloadTest(DataLayout dataLayout)
452{
453 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100454 RefWorkloadFactory factory = GetFactory();
Ruomei Yan495852f2019-05-23 11:37:33 +0100455 auto workload = CreateDepthwiseConvolution2dWorkloadTest<RefDepthwiseConvolution2dWorkload, DataType::Float32>
456 (factory, graph, dataLayout);
457
Mike Kellydb482882019-06-14 12:35:24 +0100458 TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({ 2, 2, 5, 5 })
459 : std::initializer_list<unsigned int>({ 2, 5, 5, 2 });
460 TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({ 2, 2, 5, 5 })
461 : std::initializer_list<unsigned int>({ 2, 5, 5, 2 });
462
Ruomei Yan495852f2019-05-23 11:37:33 +0100463 // Checks that inputs/outputs are as we expect them (see definition of CreateDepthwiseConvolution2dWorkloadTest).
464 CheckInputOutput(std::move(workload),
465 TensorInfo(inputShape, DataType::Float32),
466 TensorInfo(outputShape, DataType::Float32));
467}
468
Sadik Armagan1625efc2021-06-10 18:24:34 +0100469TEST_CASE("CreateDepthwiseConvolutionFloat32NhwcWorkload")
Ruomei Yan495852f2019-05-23 11:37:33 +0100470{
471 RefCreateDepthwiseConvolutionWorkloadTest(DataLayout::NHWC);
472}
473
Sadik Armagan1625efc2021-06-10 18:24:34 +0100474TEST_CASE("RefCreateFullyConnectedWithBlobWorkloadTest")
Keith Davisdf04d232020-10-23 17:20:05 +0100475{
476 Graph graph;
477 RefWorkloadFactory factory = GetFactory();
478 auto workload = CreateFullyConnectedWithBlobWorkloadTest<RefFullyConnectedWorkload,
479 armnn::DataType::Float32>(factory, graph);
480
481 // Checks that outputs and inputs are as we expect them (see definition of CreateFullyConnectedWorkloadTest).
482 float inputsQScale = 0.0f;
483 float outputQScale = 0.0f;
484 CheckInputOutput(std::move(workload),
485 TensorInfo({ 3, 1, 4, 5 }, armnn::DataType::Float32, inputsQScale),
486 TensorInfo({ 3, 7 }, armnn::DataType::Float32, outputQScale));
487}
488
telsoa01c577f2c2018-08-31 09:22:23 +0100489template <typename FullyConnectedWorkloadType, armnn::DataType DataType>
telsoa014fcda012018-03-09 14:13:49 +0000490static void RefCreateFullyConnectedWorkloadTest()
491{
492 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100493 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +0100494 auto workload = CreateFullyConnectedWorkloadTest<FullyConnectedWorkloadType, DataType>(factory, graph);
telsoa014fcda012018-03-09 14:13:49 +0000495
telsoa01c577f2c2018-08-31 09:22:23 +0100496 // Checks that outputs and inputs are as we expect them (see definition of CreateFullyConnectedWorkloadTest).
Derek Lambertif90c56d2020-01-10 17:14:08 +0000497 float inputsQScale = DataType == armnn::DataType::QAsymmU8 ? 1.0f : 0.0;
498 float outputQScale = DataType == armnn::DataType::QAsymmU8 ? 2.0f : 0.0;
telsoa014fcda012018-03-09 14:13:49 +0000499 CheckInputOutput(std::move(workload),
telsoa01c577f2c2018-08-31 09:22:23 +0100500 TensorInfo({ 3, 1, 4, 5 }, DataType, inputsQScale),
501 TensorInfo({ 3, 7 }, DataType, outputQScale));
telsoa014fcda012018-03-09 14:13:49 +0000502}
503
Sadik Armagan1625efc2021-06-10 18:24:34 +0100504TEST_CASE("CreateFullyConnectedWorkloadFloat32")
telsoa014fcda012018-03-09 14:13:49 +0000505{
Francis Murtagh43aec582019-05-27 12:14:10 +0100506 RefCreateFullyConnectedWorkloadTest<RefFullyConnectedWorkload, armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000507}
508
Sadik Armagan1625efc2021-06-10 18:24:34 +0100509TEST_CASE("CreateFullyConnectedWorkloadQuantisedAsymm8")
telsoa014fcda012018-03-09 14:13:49 +0000510{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000511 RefCreateFullyConnectedWorkloadTest<RefFullyConnectedWorkload, armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000512}
513
Sadik Armagan1625efc2021-06-10 18:24:34 +0100514TEST_CASE("CreateFullyConnectedWorkloadQuantisedSymm16")
Francis Murtagh46c09d02019-05-28 08:15:28 +0100515{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000516 RefCreateFullyConnectedWorkloadTest<RefFullyConnectedWorkload, armnn::DataType::QSymmS16>();
Francis Murtagh46c09d02019-05-28 08:15:28 +0100517}
518
narpra0155a97bc2018-10-02 14:35:53 +0100519template <typename NormalizationWorkloadType, armnn::DataType DataType>
Matteo Martincigha160b242018-10-18 10:33:23 +0100520static void RefCreateNormalizationWorkloadTest(DataLayout dataLayout)
telsoa014fcda012018-03-09 14:13:49 +0000521{
narpra0155a97bc2018-10-02 14:35:53 +0100522 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100523 RefWorkloadFactory factory = GetFactory();
Matteo Martincigha160b242018-10-18 10:33:23 +0100524 auto workload = CreateNormalizationWorkloadTest<NormalizationWorkloadType, DataType>(factory, graph, dataLayout);
525
526 TensorShape inputShape;
527 TensorShape outputShape;
528
529 switch (dataLayout)
530 {
531 case DataLayout::NHWC:
532 inputShape = { 3, 1, 5, 5 };
533 outputShape = { 3, 1, 5, 5 };
534 break;
535 case DataLayout::NCHW:
536 default:
537 inputShape = { 3, 5, 5, 1 };
538 outputShape = { 3, 5, 5, 1 };
539 break;
540 }
telsoa014fcda012018-03-09 14:13:49 +0000541
telsoa01c577f2c2018-08-31 09:22:23 +0100542 // Checks that outputs and inputs are as we expect them (see definition of CreateNormalizationWorkloadTest).
Matteo Martincigha160b242018-10-18 10:33:23 +0100543 CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
narpra0155a97bc2018-10-02 14:35:53 +0100544}
545
Sadik Armagan1625efc2021-06-10 18:24:34 +0100546TEST_CASE("CreateRefNormalizationFloat32NchwWorkload")
narpra0155a97bc2018-10-02 14:35:53 +0100547{
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100548 RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
Matteo Martincigha160b242018-10-18 10:33:23 +0100549}
550
Sadik Armagan1625efc2021-06-10 18:24:34 +0100551TEST_CASE("CreateRefNormalizationFloat32NhwcWorkload")
Matteo Martincigha160b242018-10-18 10:33:23 +0100552{
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100553 RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
554}
555
Sadik Armagan1625efc2021-06-10 18:24:34 +0100556TEST_CASE("CreateRefNormalizationUint8NchwWorkload")
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100557{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000558 RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QAsymmU8>(DataLayout::NCHW);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100559}
560
Sadik Armagan1625efc2021-06-10 18:24:34 +0100561TEST_CASE("CreateRefNormalizationUint8NhwcWorkload")
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100562{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000563 RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QAsymmU8>(DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000564}
565
Sadik Armagan1625efc2021-06-10 18:24:34 +0100566TEST_CASE("CreateRefNormalizationInt16NchwWorkload")
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100567{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000568 RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QSymmS16>(DataLayout::NCHW);
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100569}
570
Sadik Armagan1625efc2021-06-10 18:24:34 +0100571TEST_CASE("CreateRefNormalizationInt16NhwcWorkload")
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100572{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000573 RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QSymmS16>(DataLayout::NHWC);
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100574}
575
telsoa01c577f2c2018-08-31 09:22:23 +0100576template <typename Pooling2dWorkloadType, armnn::DataType DataType>
James Conroy69482272018-10-19 10:41:35 +0100577static void RefCreatePooling2dWorkloadTest(DataLayout dataLayout)
telsoa014fcda012018-03-09 14:13:49 +0000578{
579 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100580 RefWorkloadFactory factory = GetFactory();
James Conroy69482272018-10-19 10:41:35 +0100581 auto workload = CreatePooling2dWorkloadTest<Pooling2dWorkloadType, DataType>(factory, graph, dataLayout);
582
583 TensorShape inputShape;
584 TensorShape outputShape;
585
586 switch (dataLayout)
587 {
588 case DataLayout::NHWC:
589 inputShape = { 3, 5, 5, 2 };
590 outputShape = { 3, 2, 4, 2 };
591 break;
592 case DataLayout::NCHW:
593 default:
594 inputShape = { 3, 2, 5, 5 };
595 outputShape = { 3, 2, 2, 4 };
596 }
telsoa014fcda012018-03-09 14:13:49 +0000597
telsoa01c577f2c2018-08-31 09:22:23 +0100598 // Checks that outputs and inputs are as we expect them (see definition of CreatePooling2dWorkloadTest).
James Conroy69482272018-10-19 10:41:35 +0100599 CheckInputOutput(std::move(workload),
600 TensorInfo(inputShape, DataType),
601 TensorInfo(outputShape, DataType));
telsoa014fcda012018-03-09 14:13:49 +0000602}
603
Sadik Armagan1625efc2021-06-10 18:24:34 +0100604TEST_CASE("CreatePooling2dFloat32Workload")
telsoa014fcda012018-03-09 14:13:49 +0000605{
Teresa Charlina3b20472019-06-06 11:12:32 +0100606 RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
James Conroy69482272018-10-19 10:41:35 +0100607}
608
Sadik Armagan1625efc2021-06-10 18:24:34 +0100609TEST_CASE("CreatePooling2dFloat32NhwcWorkload")
James Conroy69482272018-10-19 10:41:35 +0100610{
Teresa Charlina3b20472019-06-06 11:12:32 +0100611 RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000612}
613
Sadik Armagan1625efc2021-06-10 18:24:34 +0100614TEST_CASE("CreatePooling2dUint8Workload")
telsoa014fcda012018-03-09 14:13:49 +0000615{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000616 RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::QAsymmU8>(DataLayout::NCHW);
James Conroy69482272018-10-19 10:41:35 +0100617}
618
Sadik Armagan1625efc2021-06-10 18:24:34 +0100619TEST_CASE("CreatePooling2dUint8NhwcWorkload")
James Conroy69482272018-10-19 10:41:35 +0100620{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000621 RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::QAsymmU8>(DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000622}
623
Sadik Armagan1625efc2021-06-10 18:24:34 +0100624TEST_CASE("CreatePooling2dInt16Workload")
Teresa Charlin0434df62019-06-06 13:40:35 +0100625{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000626 RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::QSymmS16>(DataLayout::NCHW);
Teresa Charlin0434df62019-06-06 13:40:35 +0100627}
628
Sadik Armagan1625efc2021-06-10 18:24:34 +0100629TEST_CASE("CreatePooling2dInt16NhwcWorkload")
Teresa Charlin0434df62019-06-06 13:40:35 +0100630{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000631 RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::QSymmS16>(DataLayout::NHWC);
Teresa Charlin0434df62019-06-06 13:40:35 +0100632}
633
telsoa01c577f2c2018-08-31 09:22:23 +0100634template <typename SoftmaxWorkloadType, armnn::DataType DataType>
telsoa014fcda012018-03-09 14:13:49 +0000635static void RefCreateSoftmaxWorkloadTest()
636{
637 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100638 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +0100639 auto workload = CreateSoftmaxWorkloadTest<SoftmaxWorkloadType, DataType>(factory, graph);
telsoa014fcda012018-03-09 14:13:49 +0000640
telsoa01c577f2c2018-08-31 09:22:23 +0100641 // Checks that outputs and inputs are as we expect them (see definition of CreateSoftmaxWorkloadTest).
Sadik Armaganbe88a572020-04-30 11:39:37 +0100642
643 armnn::TensorInfo tensorInfo({4, 1}, DataType);
644 if (DataType == armnn::DataType::QAsymmU8)
645 {
646 tensorInfo.SetQuantizationOffset(0);
647 tensorInfo.SetQuantizationScale(1.f / 256);
648 }
649 else if (DataType == armnn::DataType::QAsymmS8)
650 {
651 tensorInfo.SetQuantizationOffset(-128);
652 tensorInfo.SetQuantizationScale(1.f / 256);
653 }
telsoa014fcda012018-03-09 14:13:49 +0000654 CheckInputOutput(
655 std::move(workload),
Sadik Armaganbe88a572020-04-30 11:39:37 +0100656 tensorInfo,
657 tensorInfo);
telsoa014fcda012018-03-09 14:13:49 +0000658}
659
Sadik Armagan1625efc2021-06-10 18:24:34 +0100660TEST_CASE("CreateSoftmaxFloat32Workload")
telsoa014fcda012018-03-09 14:13:49 +0000661{
nikraj01a121de32019-05-29 10:51:05 +0100662 RefCreateSoftmaxWorkloadTest<RefSoftmaxWorkload, armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000663}
664
Sadik Armagan1625efc2021-06-10 18:24:34 +0100665TEST_CASE("CreateSoftmaxFloat16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100666{
667 RefCreateSoftmaxWorkloadTest<RefSoftmaxWorkload, armnn::DataType::Float16>();
668}
669
Sadik Armagan1625efc2021-06-10 18:24:34 +0100670TEST_CASE("CreateSoftmaxQuantisedAsymm8Workload")
telsoa014fcda012018-03-09 14:13:49 +0000671{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000672 RefCreateSoftmaxWorkloadTest<RefSoftmaxWorkload, armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000673}
674
Sadik Armagan1625efc2021-06-10 18:24:34 +0100675TEST_CASE("CreateSoftmaxQuantisedSymm16Workload")
nikraj01248683f2019-05-29 16:46:50 +0100676{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000677 RefCreateSoftmaxWorkloadTest<RefSoftmaxWorkload, armnn::DataType::QSymmS16>();
nikraj01248683f2019-05-29 16:46:50 +0100678}
679
telsoa01c577f2c2018-08-31 09:22:23 +0100680template <typename SplitterWorkloadType, armnn::DataType DataType>
telsoa014fcda012018-03-09 14:13:49 +0000681static void RefCreateSplitterWorkloadTest()
682{
683 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100684 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +0100685 auto workload = CreateSplitterWorkloadTest<SplitterWorkloadType, DataType>(factory, graph);
telsoa014fcda012018-03-09 14:13:49 +0000686
telsoa01c577f2c2018-08-31 09:22:23 +0100687 // Checks that outputs are as we expect them (see definition of CreateSplitterWorkloadTest).
telsoa014fcda012018-03-09 14:13:49 +0000688 SplitterQueueDescriptor queueDescriptor = workload->GetData();
Jan Eilersbb446e52020-04-02 13:56:54 +0100689 auto inputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100690 CHECK((inputHandle->GetTensorInfo() == TensorInfo({ 5, 7, 7 }, DataType)));
surmeh013537c2c2018-05-18 16:31:43 +0100691
Jan Eilersbb446e52020-04-02 13:56:54 +0100692 auto outputHandle0 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100693 CHECK((outputHandle0->GetTensorInfo() == TensorInfo({ 1, 7, 7 }, DataType)));
surmeh013537c2c2018-05-18 16:31:43 +0100694
Jan Eilersbb446e52020-04-02 13:56:54 +0100695 auto outputHandle1 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[1]);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100696 CHECK((outputHandle1->GetTensorInfo() == TensorInfo({ 2, 7, 7 }, DataType)));
surmeh013537c2c2018-05-18 16:31:43 +0100697
Jan Eilersbb446e52020-04-02 13:56:54 +0100698 auto outputHandle2 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[2]);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100699 CHECK((outputHandle2->GetTensorInfo() == TensorInfo({ 2, 7, 7 }, DataType)));
telsoa014fcda012018-03-09 14:13:49 +0000700}
701
Sadik Armagan1625efc2021-06-10 18:24:34 +0100702TEST_CASE("CreateSplitterFloat32Workload")
telsoa014fcda012018-03-09 14:13:49 +0000703{
Ruomei Yan25339c32019-05-28 16:48:20 +0100704 RefCreateSplitterWorkloadTest<RefSplitterWorkload, armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000705}
706
Sadik Armagan1625efc2021-06-10 18:24:34 +0100707TEST_CASE("CreateSplitterFloat16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100708{
709 RefCreateSplitterWorkloadTest<RefSplitterWorkload, armnn::DataType::Float16>();
710}
711
Sadik Armagan1625efc2021-06-10 18:24:34 +0100712TEST_CASE("CreateSplitterUint8Workload")
telsoa014fcda012018-03-09 14:13:49 +0000713{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000714 RefCreateSplitterWorkloadTest<RefSplitterWorkload, armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000715}
716
Jim Flynne242f2d2019-05-22 14:24:13 +0100717template <typename SplitterWorkloadType, typename ConcatWorkloadType, armnn::DataType DataType>
718static void RefCreateSplitterConcatWorkloadTest()
telsoa014fcda012018-03-09 14:13:49 +0000719{
telsoa01c577f2c2018-08-31 09:22:23 +0100720 // Tests that it is possible to decide which output of the splitter layer
Jim Flynne242f2d2019-05-22 14:24:13 +0100721 // should be lined to which input of the concat layer.
telsoa01c577f2c2018-08-31 09:22:23 +0100722 // We tested that is is possible to specify 0th output
Jim Flynne242f2d2019-05-22 14:24:13 +0100723 // of the splitter to be the 1st input to the concat and the 1st output of the splitter to be 0th input
724 // of the concat.
telsoa014fcda012018-03-09 14:13:49 +0000725
726 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100727 RefWorkloadFactory factory = GetFactory();
Jim Flynne242f2d2019-05-22 14:24:13 +0100728 auto workloads = CreateSplitterConcatWorkloadTest<SplitterWorkloadType, ConcatWorkloadType, DataType>
729 (factory, graph);
telsoa014fcda012018-03-09 14:13:49 +0000730
731 auto wlSplitter = std::move(workloads.first);
Jim Flynne242f2d2019-05-22 14:24:13 +0100732 auto wlConcat = std::move(workloads.second);
telsoa014fcda012018-03-09 14:13:49 +0000733
telsoa01c577f2c2018-08-31 09:22:23 +0100734 //Checks that the index of inputs/outputs matches what we declared on InputDescriptor construction.
Matthew Bentham4cefc412019-06-18 16:14:34 +0100735 armnn::RefTensorHandle* sOut0 = dynamic_cast<armnn::RefTensorHandle*>(wlSplitter->GetData().m_Outputs[0]);
736 armnn::RefTensorHandle* sOut1 = dynamic_cast<armnn::RefTensorHandle*>(wlSplitter->GetData().m_Outputs[1]);
737 armnn::RefTensorHandle* mIn0 = dynamic_cast<armnn::RefTensorHandle*>(wlConcat->GetData().m_Inputs[0]);
738 armnn::RefTensorHandle* mIn1 = dynamic_cast<armnn::RefTensorHandle*>(wlConcat->GetData().m_Inputs[1]);
telsoa014fcda012018-03-09 14:13:49 +0000739
Sadik Armagan1625efc2021-06-10 18:24:34 +0100740 CHECK(sOut0);
741 CHECK(sOut1);
742 CHECK(mIn0);
743 CHECK(mIn1);
telsoa014fcda012018-03-09 14:13:49 +0000744
745 bool validDataPointers = (sOut0 == mIn1) && (sOut1 == mIn0);
746
Sadik Armagan1625efc2021-06-10 18:24:34 +0100747 CHECK(validDataPointers);
telsoa014fcda012018-03-09 14:13:49 +0000748}
749
Sadik Armagan1625efc2021-06-10 18:24:34 +0100750TEST_CASE("CreateSplitterConcatFloat32")
telsoa014fcda012018-03-09 14:13:49 +0000751{
Ruomei Yan25339c32019-05-28 16:48:20 +0100752 RefCreateSplitterConcatWorkloadTest<RefSplitterWorkload, RefConcatWorkload, DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000753}
754
Sadik Armagan1625efc2021-06-10 18:24:34 +0100755TEST_CASE("CreateSplitterConcatFloat16")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100756{
757 RefCreateSplitterConcatWorkloadTest<RefSplitterWorkload, RefConcatWorkload, DataType::Float16>();
758}
759
Sadik Armagan1625efc2021-06-10 18:24:34 +0100760TEST_CASE("CreateSplitterConcatUint8")
telsoa014fcda012018-03-09 14:13:49 +0000761{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000762 RefCreateSplitterConcatWorkloadTest<RefSplitterWorkload, RefConcatWorkload, DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000763}
764
telsoa01c577f2c2018-08-31 09:22:23 +0100765template <typename SplitterWorkloadType, typename ActivationWorkloadType, armnn::DataType DataType>
telsoa014fcda012018-03-09 14:13:49 +0000766static void RefCreateSingleOutputMultipleInputsTest()
767{
telsoa01c577f2c2018-08-31 09:22:23 +0100768 // Tests that it is possible to assign multiple (two) different layers to each of the outputs of a splitter layer.
769 // We created a splitter with two outputs. That each of those outputs is used by two different activation layers.
telsoa014fcda012018-03-09 14:13:49 +0000770
771 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100772 RefWorkloadFactory factory = GetFactory();
telsoa014fcda012018-03-09 14:13:49 +0000773 std::unique_ptr<SplitterWorkloadType> wlSplitter;
774 std::unique_ptr<ActivationWorkloadType> wlActiv0_0;
775 std::unique_ptr<ActivationWorkloadType> wlActiv0_1;
776 std::unique_ptr<ActivationWorkloadType> wlActiv1_0;
777 std::unique_ptr<ActivationWorkloadType> wlActiv1_1;
778
779 CreateSplitterMultipleInputsOneOutputWorkloadTest<SplitterWorkloadType,
telsoa01c577f2c2018-08-31 09:22:23 +0100780 ActivationWorkloadType, DataType>(factory, graph, wlSplitter, wlActiv0_0, wlActiv0_1, wlActiv1_0, wlActiv1_1);
telsoa014fcda012018-03-09 14:13:49 +0000781
Matthew Bentham4cefc412019-06-18 16:14:34 +0100782 armnn::RefTensorHandle* sOut0 = dynamic_cast<armnn::RefTensorHandle*>(wlSplitter->GetData().m_Outputs[0]);
783 armnn::RefTensorHandle* sOut1 = dynamic_cast<armnn::RefTensorHandle*>(wlSplitter->GetData().m_Outputs[1]);
784 armnn::RefTensorHandle* activ0_0Im = dynamic_cast<armnn::RefTensorHandle*>(wlActiv0_0->GetData().m_Inputs[0]);
785 armnn::RefTensorHandle* activ0_1Im = dynamic_cast<armnn::RefTensorHandle*>(wlActiv0_1->GetData().m_Inputs[0]);
786 armnn::RefTensorHandle* activ1_0Im = dynamic_cast<armnn::RefTensorHandle*>(wlActiv1_0->GetData().m_Inputs[0]);
787 armnn::RefTensorHandle* activ1_1Im = dynamic_cast<armnn::RefTensorHandle*>(wlActiv1_1->GetData().m_Inputs[0]);
telsoa014fcda012018-03-09 14:13:49 +0000788
789
Sadik Armagan1625efc2021-06-10 18:24:34 +0100790 CHECK(sOut0);
791 CHECK(sOut1);
792 CHECK(activ0_0Im);
793 CHECK(activ0_1Im);
794 CHECK(activ1_0Im);
795 CHECK(activ1_1Im);
telsoa014fcda012018-03-09 14:13:49 +0000796
797 bool validDataPointers = (sOut0 == activ0_0Im) && (sOut0 == activ0_1Im) &&
798 (sOut1 == activ1_0Im) && (sOut1 == activ1_1Im);
799
Sadik Armagan1625efc2021-06-10 18:24:34 +0100800 CHECK(validDataPointers);
telsoa014fcda012018-03-09 14:13:49 +0000801}
802
Sadik Armagan1625efc2021-06-10 18:24:34 +0100803TEST_CASE("CreateSingleOutputMultipleInputsFloat32")
telsoa014fcda012018-03-09 14:13:49 +0000804{
Ruomei Yan25339c32019-05-28 16:48:20 +0100805 RefCreateSingleOutputMultipleInputsTest<RefSplitterWorkload, RefActivationWorkload,
telsoa01c577f2c2018-08-31 09:22:23 +0100806 armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000807}
808
Sadik Armagan1625efc2021-06-10 18:24:34 +0100809TEST_CASE("CreateSingleOutputMultipleInputsUint8")
telsoa014fcda012018-03-09 14:13:49 +0000810{
Ruomei Yan25339c32019-05-28 16:48:20 +0100811 RefCreateSingleOutputMultipleInputsTest<RefSplitterWorkload, RefActivationWorkload,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000812 armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000813}
814
telsoa01c577f2c2018-08-31 09:22:23 +0100815template <typename ResizeBilinearWorkloadType, armnn::DataType DataType>
James Conroy59540822018-10-11 12:39:05 +0100816static void RefCreateResizeBilinearTest(DataLayout dataLayout)
telsoa014fcda012018-03-09 14:13:49 +0000817{
818 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100819 RefWorkloadFactory factory = GetFactory();
James Conroy59540822018-10-11 12:39:05 +0100820 auto workload = CreateResizeBilinearWorkloadTest<ResizeBilinearWorkloadType, DataType>(factory, graph, dataLayout);
821
822 TensorShape inputShape;
823 TensorShape outputShape;
824
825 switch (dataLayout)
826 {
827 case DataLayout::NHWC:
828 inputShape = { 2, 4, 4, 3 };
829 outputShape = { 2, 2, 2, 3 };
830 break;
James Conroy69482272018-10-19 10:41:35 +0100831 case DataLayout::NCHW:
832 default:
James Conroy59540822018-10-11 12:39:05 +0100833 inputShape = { 2, 3, 4, 4 };
834 outputShape = { 2, 3, 2, 2 };
835 }
telsoa014fcda012018-03-09 14:13:49 +0000836
telsoa01c577f2c2018-08-31 09:22:23 +0100837 // Checks that outputs and inputs are as we expect them (see definition of CreateResizeBilinearWorkloadTest).
James Conroy69482272018-10-19 10:41:35 +0100838 CheckInputOutput(std::move(workload),
839 TensorInfo(inputShape, DataType),
840 TensorInfo(outputShape, DataType));
telsoa014fcda012018-03-09 14:13:49 +0000841}
842
Sadik Armagan1625efc2021-06-10 18:24:34 +0100843TEST_CASE("CreateResizeBilinearFloat32")
telsoa014fcda012018-03-09 14:13:49 +0000844{
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100845 RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
telsoa014fcda012018-03-09 14:13:49 +0000846}
847
Sadik Armagan1625efc2021-06-10 18:24:34 +0100848TEST_CASE("CreateResizeBilinearFloat16")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100849{
850 RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::Float16>(DataLayout::NCHW);
851}
852
Sadik Armagan1625efc2021-06-10 18:24:34 +0100853TEST_CASE("CreateResizeBilinearUint8")
telsoa014fcda012018-03-09 14:13:49 +0000854{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000855 RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::QAsymmU8>(DataLayout::NCHW);
James Conroy59540822018-10-11 12:39:05 +0100856}
857
Sadik Armagan1625efc2021-06-10 18:24:34 +0100858TEST_CASE("CreateResizeBilinearQuantisedAsymm16")
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +0100859{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000860 RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::QSymmS16>(DataLayout::NCHW);
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +0100861}
862
Sadik Armagan1625efc2021-06-10 18:24:34 +0100863TEST_CASE("CreateResizeBilinearFloat32Nhwc")
James Conroy59540822018-10-11 12:39:05 +0100864{
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100865 RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000866}
867
Francis Murtagh57f13d52019-06-24 14:24:36 +0100868template <typename BatchToSpaceNdWorkloadType, armnn::DataType DataType>
869static void RefCreateBatchToSpaceNdTest()
870{
871 Graph graph;
872 RefWorkloadFactory factory;
873
874 auto workload = CreateBatchToSpaceNdWorkloadTest<BatchToSpaceNdWorkloadType, DataType>(factory, graph);
875
876 CheckInputOutput(std::move(workload),
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100877 TensorInfo({ 1, 1, 1, 1 }, DataType),
878 TensorInfo({ 1, 1, 1, 1 }, DataType));
Francis Murtagh57f13d52019-06-24 14:24:36 +0100879}
880
Sadik Armagan1625efc2021-06-10 18:24:34 +0100881TEST_CASE("CreateBatchToSpaceNdFloat32")
Francis Murtagh57f13d52019-06-24 14:24:36 +0100882{
883 RefCreateBatchToSpaceNdTest<RefBatchToSpaceNdWorkload, armnn::DataType::Float32>();
884}
885
Sadik Armagan1625efc2021-06-10 18:24:34 +0100886TEST_CASE("CreateBatchToSpaceNdFloat16")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100887{
888 RefCreateBatchToSpaceNdTest<RefBatchToSpaceNdWorkload, armnn::DataType::Float16>();
889}
890
Sadik Armagan1625efc2021-06-10 18:24:34 +0100891TEST_CASE("CreateBatchToSpaceNdUint8")
Francis Murtagh57f13d52019-06-24 14:24:36 +0100892{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000893 RefCreateBatchToSpaceNdTest<RefBatchToSpaceNdWorkload, armnn::DataType::QAsymmU8>();
Francis Murtagh57f13d52019-06-24 14:24:36 +0100894}
895
Sadik Armagan1625efc2021-06-10 18:24:34 +0100896TEST_CASE("CreateBatchToSpaceNdQSymm16")
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100897{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000898 RefCreateBatchToSpaceNdTest<RefBatchToSpaceNdWorkload, armnn::DataType::QSymmS16>();
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100899}
900
Matteo Martincighb63973e2018-10-16 16:23:33 +0100901template <typename L2NormalizationWorkloadType, armnn::DataType DataType>
902static void RefCreateL2NormalizationTest(DataLayout dataLayout)
telsoa014fcda012018-03-09 14:13:49 +0000903{
904 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100905 RefWorkloadFactory factory = GetFactory();
Matteo Martincighb63973e2018-10-16 16:23:33 +0100906 auto workload =
907 CreateL2NormalizationWorkloadTest<L2NormalizationWorkloadType, DataType>(factory, graph, dataLayout);
908
909 TensorShape inputShape;
910 TensorShape outputShape;
911
912 switch (dataLayout)
913 {
914 case DataLayout::NHWC:
915 inputShape = { 5, 50, 67, 20 };
916 outputShape = { 5, 50, 67, 20 };
917 break;
918 case DataLayout::NCHW:
919 default:
920 inputShape = { 5, 20, 50, 67 };
921 outputShape = { 5, 20, 50, 67 };
922 break;
923 }
telsoa014fcda012018-03-09 14:13:49 +0000924
telsoa01c577f2c2018-08-31 09:22:23 +0100925 // Checks that outputs and inputs are as we expect them (see definition of CreateL2NormalizationWorkloadTest).
Matteo Martincighb63973e2018-10-16 16:23:33 +0100926 CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
927}
928
Sadik Armagan1625efc2021-06-10 18:24:34 +0100929TEST_CASE("CreateL2NormalizationFloat32")
Matteo Martincighb63973e2018-10-16 16:23:33 +0100930{
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100931 RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
Matteo Martincighb63973e2018-10-16 16:23:33 +0100932}
933
Sadik Armagan1625efc2021-06-10 18:24:34 +0100934TEST_CASE("CreateL2NormalizationFloat32Nhwc")
Matteo Martincighb63973e2018-10-16 16:23:33 +0100935{
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100936 RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
937}
938
Sadik Armagan1625efc2021-06-10 18:24:34 +0100939TEST_CASE("CreateL2NormalizationInt16")
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100940{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000941 RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::QSymmS16>(DataLayout::NCHW);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100942}
943
Sadik Armagan1625efc2021-06-10 18:24:34 +0100944TEST_CASE("CreateL2NormalizationInt16Nhwc")
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100945{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000946 RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::QSymmS16>(DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000947}
948
Sadik Armagan1625efc2021-06-10 18:24:34 +0100949TEST_CASE("CreateL2NormalizationUint8")
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100950{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000951 RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::QAsymmU8>(DataLayout::NCHW);
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100952}
953
Sadik Armagan1625efc2021-06-10 18:24:34 +0100954TEST_CASE("CreateL2NormalizationUint8Nhwc")
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100955{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000956 RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::QAsymmU8>(DataLayout::NHWC);
Ferran Balaguerc6138d82019-06-13 17:23:50 +0100957}
958
telsoa01c577f2c2018-08-31 09:22:23 +0100959template <typename ReshapeWorkloadType, armnn::DataType DataType>
telsoa014fcda012018-03-09 14:13:49 +0000960static void RefCreateReshapeWorkloadTest()
961{
962 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100963 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +0100964 auto workload = CreateReshapeWorkloadTest<ReshapeWorkloadType, DataType>(factory, graph);
telsoa014fcda012018-03-09 14:13:49 +0000965
telsoa01c577f2c2018-08-31 09:22:23 +0100966 // Checks that outputs and inputs are as we expect them (see definition of CreateReshapeWorkloadTest).
telsoa014fcda012018-03-09 14:13:49 +0000967 CheckInputOutput(
968 std::move(workload),
telsoa01c577f2c2018-08-31 09:22:23 +0100969 TensorInfo({ 4, 1 }, DataType),
970 TensorInfo({ 1, 4 }, DataType));
telsoa014fcda012018-03-09 14:13:49 +0000971}
972
Sadik Armagan1625efc2021-06-10 18:24:34 +0100973TEST_CASE("CreateReshapeWorkloadFloat32")
telsoa014fcda012018-03-09 14:13:49 +0000974{
Nina Drozd2f2778f2019-05-27 10:37:05 +0100975 RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000976}
977
Sadik Armagan1625efc2021-06-10 18:24:34 +0100978TEST_CASE("CreateReshapeWorkloadQuantisedAsymm8")
telsoa014fcda012018-03-09 14:13:49 +0000979{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000980 RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000981}
982
Sadik Armagan1625efc2021-06-10 18:24:34 +0100983TEST_CASE("CreateReshapeWorkloadQuantisedSymm16")
Nina Drozd8ed4b8c2019-05-29 10:41:04 +0100984{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000985 RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::QSymmS16>();
Nina Drozd8ed4b8c2019-05-29 10:41:04 +0100986}
987
Jim Flynne242f2d2019-05-22 14:24:13 +0100988template <typename ConcatWorkloadType, armnn::DataType DataType>
989static void RefCreateConcatWorkloadTest(const armnn::TensorShape& outputShape,
narpra015cdda352018-11-19 15:30:27 +0000990 unsigned int concatAxis)
991{
992 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100993 RefWorkloadFactory factory = GetFactory();
Jim Flynne242f2d2019-05-22 14:24:13 +0100994 auto workload = CreateConcatWorkloadTest<ConcatWorkloadType, DataType>(factory, graph, outputShape, concatAxis);
narpra015cdda352018-11-19 15:30:27 +0000995
996 CheckInputsOutput(std::move(workload),
997 TensorInfo({ 2, 3, 2, 5 }, DataType),
998 TensorInfo({ 2, 3, 2, 5 }, DataType),
999 TensorInfo(outputShape, DataType));
1000}
1001
Sadik Armagan1625efc2021-06-10 18:24:34 +01001002TEST_CASE("CreateConcatDim0Float32Workload")
narpra015cdda352018-11-19 15:30:27 +00001003{
Jim Flynne242f2d2019-05-22 14:24:13 +01001004 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 4, 3, 2, 5 }, 0);
narpra015cdda352018-11-19 15:30:27 +00001005}
1006
Sadik Armagan1625efc2021-06-10 18:24:34 +01001007TEST_CASE("CreateConcatDim0Float16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +01001008{
1009 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float16>({ 4, 3, 2, 5 }, 0);
1010}
1011
Sadik Armagan1625efc2021-06-10 18:24:34 +01001012TEST_CASE("CreateConcatDim0Uint8Workload")
narpra015cdda352018-11-19 15:30:27 +00001013{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001014 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QAsymmU8>({ 4, 3, 2, 5 }, 0);
Jim Flynncbb66aa2019-05-15 13:03:54 +01001015}
1016
Sadik Armagan1625efc2021-06-10 18:24:34 +01001017TEST_CASE("CreateConcatDim0Uint16Workload")
Jim Flynncbb66aa2019-05-15 13:03:54 +01001018{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001019 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QSymmS16>({ 4, 3, 2, 5 }, 0);
narpra015cdda352018-11-19 15:30:27 +00001020}
1021
Sadik Armagan1625efc2021-06-10 18:24:34 +01001022TEST_CASE("CreateConcatDim1Float32Workload")
narpra015cdda352018-11-19 15:30:27 +00001023{
Jim Flynne242f2d2019-05-22 14:24:13 +01001024 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 2, 6, 2, 5 }, 1);
narpra015cdda352018-11-19 15:30:27 +00001025}
1026
Sadik Armagan1625efc2021-06-10 18:24:34 +01001027TEST_CASE("CreateConcatDim1Uint8Workload")
narpra015cdda352018-11-19 15:30:27 +00001028{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001029 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QAsymmU8>({ 2, 6, 2, 5 }, 1);
narpra015cdda352018-11-19 15:30:27 +00001030}
1031
Sadik Armagan1625efc2021-06-10 18:24:34 +01001032TEST_CASE("CreateConcatDim2Float32Workload")
narpra015cdda352018-11-19 15:30:27 +00001033{
Jim Flynne242f2d2019-05-22 14:24:13 +01001034 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 2, 3, 4, 5 }, 2);
narpra015cdda352018-11-19 15:30:27 +00001035}
1036
Sadik Armagan1625efc2021-06-10 18:24:34 +01001037TEST_CASE("CreateConcatDim2Uint8Workload")
narpra015cdda352018-11-19 15:30:27 +00001038{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001039 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QAsymmU8>({ 2, 3, 4, 5 }, 2);
narpra015cdda352018-11-19 15:30:27 +00001040}
1041
Sadik Armagan1625efc2021-06-10 18:24:34 +01001042TEST_CASE("CreateConcatDim3Float32Workload")
narpra015cdda352018-11-19 15:30:27 +00001043{
Jim Flynne242f2d2019-05-22 14:24:13 +01001044 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 2, 3, 2, 10 }, 3);
narpra015cdda352018-11-19 15:30:27 +00001045}
1046
Sadik Armagan1625efc2021-06-10 18:24:34 +01001047TEST_CASE("CreateConcatDim3Uint8Workload")
narpra015cdda352018-11-19 15:30:27 +00001048{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001049 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QAsymmU8>({ 2, 3, 2, 10 }, 3);
narpra015cdda352018-11-19 15:30:27 +00001050}
1051
Nina Drozd58ef2c62019-05-16 12:09:18 +01001052template <typename ConstantWorkloadType, armnn::DataType DataType>
1053static void RefCreateConstantWorkloadTest(const armnn::TensorShape& outputShape)
1054{
1055 armnn::Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +01001056 RefWorkloadFactory factory = GetFactory();
Nina Drozd58ef2c62019-05-16 12:09:18 +01001057 auto workload = CreateConstantWorkloadTest<ConstantWorkloadType, DataType>(factory, graph, outputShape);
1058
1059 // Check output is as expected
1060 auto queueDescriptor = workload->GetData();
Jan Eilersbb446e52020-04-02 13:56:54 +01001061 auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +01001062 CHECK((outputHandle->GetTensorInfo() == TensorInfo(outputShape, DataType)));
Nina Drozd58ef2c62019-05-16 12:09:18 +01001063}
1064
Sadik Armagan1625efc2021-06-10 18:24:34 +01001065TEST_CASE("CreateConstantUint8Workload")
Nina Drozd58ef2c62019-05-16 12:09:18 +01001066{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001067 RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::QAsymmU8>({ 2, 3, 2, 10 });
Nina Drozd58ef2c62019-05-16 12:09:18 +01001068}
1069
Sadik Armagan1625efc2021-06-10 18:24:34 +01001070TEST_CASE("CreateConstantInt16Workload")
Nina Drozd58ef2c62019-05-16 12:09:18 +01001071{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001072 RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::QSymmS16>({ 2, 3, 2, 10 });
Nina Drozd58ef2c62019-05-16 12:09:18 +01001073}
1074
Sadik Armagan1625efc2021-06-10 18:24:34 +01001075TEST_CASE("CreateConstantFloat32Workload")
Nina Drozd58ef2c62019-05-16 12:09:18 +01001076{
1077 RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::Float32>({ 2, 3, 2, 10 });
1078}
1079
Sadik Armagan1625efc2021-06-10 18:24:34 +01001080TEST_CASE("CreateConstantSigned32Workload")
Nina Drozd58ef2c62019-05-16 12:09:18 +01001081{
1082 RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::Signed32>({ 2, 3, 2, 10 });
1083}
1084
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001085static void RefCreatePreluWorkloadTest(const armnn::TensorShape& inputShape,
1086 const armnn::TensorShape& alphaShape,
1087 const armnn::TensorShape& outputShape,
1088 armnn::DataType dataType)
Matteo Martincighab9e5252019-06-13 17:27:46 +01001089{
1090 armnn::Graph graph;
1091 RefWorkloadFactory factory;
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001092 auto workload = CreatePreluWorkloadTest<RefPreluWorkload>(factory,
1093 graph,
1094 inputShape,
1095 alphaShape,
1096 outputShape,
1097 dataType);
Matteo Martincighab9e5252019-06-13 17:27:46 +01001098
1099 // Check output is as expected
1100 auto queueDescriptor = workload->GetData();
Jan Eilersbb446e52020-04-02 13:56:54 +01001101 auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +01001102 CHECK((outputHandle->GetTensorInfo() == TensorInfo(outputShape, dataType)));
Matteo Martincighab9e5252019-06-13 17:27:46 +01001103}
1104
Sadik Armagan1625efc2021-06-10 18:24:34 +01001105TEST_CASE("CreatePreluFloat32Workload")
Matteo Martincighab9e5252019-06-13 17:27:46 +01001106{
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001107 RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::Float32);
Matteo Martincighab9e5252019-06-13 17:27:46 +01001108}
1109
Sadik Armagan1625efc2021-06-10 18:24:34 +01001110TEST_CASE("CreatePreluFloat16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +01001111{
1112 RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::Float16);
1113}
1114
Sadik Armagan1625efc2021-06-10 18:24:34 +01001115TEST_CASE("CreatePreluUint8Workload")
Matteo Martincighab9e5252019-06-13 17:27:46 +01001116{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001117 RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::QAsymmU8);
Matteo Martincighab9e5252019-06-13 17:27:46 +01001118}
1119
Sadik Armagan1625efc2021-06-10 18:24:34 +01001120TEST_CASE("CreatePreluInt16Workload")
Matteo Martincighab9e5252019-06-13 17:27:46 +01001121{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001122 RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::QSymmS16);
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001123}
1124
Sadik Armagan1625efc2021-06-10 18:24:34 +01001125TEST_CASE("CreatePreluFloat32NoBroadcastWorkload")
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001126{
Sadik Armagan1625efc2021-06-10 18:24:34 +01001127 CHECK_THROWS_AS(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001128 armnn::DataType::Float32),
1129 armnn::InvalidArgumentException);
1130}
1131
Sadik Armagan1625efc2021-06-10 18:24:34 +01001132TEST_CASE("CreatePreluFloat16NoBroadcastWorkload")
Matthew Jackson9bff1442019-09-12 09:08:23 +01001133{
Sadik Armagan1625efc2021-06-10 18:24:34 +01001134 CHECK_THROWS_AS(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
Matthew Jackson9bff1442019-09-12 09:08:23 +01001135 armnn::DataType::Float16),
1136 armnn::InvalidArgumentException);
1137}
1138
Sadik Armagan1625efc2021-06-10 18:24:34 +01001139TEST_CASE("CreatePreluUint8NoBroadcastWorkload")
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001140{
Sadik Armagan1625efc2021-06-10 18:24:34 +01001141 CHECK_THROWS_AS(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
Derek Lambertif90c56d2020-01-10 17:14:08 +00001142 armnn::DataType::QAsymmU8),
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001143 armnn::InvalidArgumentException);
1144}
1145
Sadik Armagan1625efc2021-06-10 18:24:34 +01001146TEST_CASE("CreatePreluInt16NoBroadcastWorkload")
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001147{
Sadik Armagan1625efc2021-06-10 18:24:34 +01001148 CHECK_THROWS_AS(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
Derek Lambertif90c56d2020-01-10 17:14:08 +00001149 armnn::DataType::QSymmS16),
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001150 armnn::InvalidArgumentException);
Matteo Martincighab9e5252019-06-13 17:27:46 +01001151}
1152
James Conroy60597842019-07-02 10:57:56 +01001153template <typename SpaceToDepthWorkloadType, armnn::DataType DataType>
1154static void RefCreateSpaceToDepthWorkloadTest()
1155{
1156 Graph graph;
1157 RefWorkloadFactory factory;
1158
1159 auto workload = CreateSpaceToDepthWorkloadTest<SpaceToDepthWorkloadType, DataType>(factory, graph);
1160
1161 CheckInputOutput(std::move(workload),
1162 TensorInfo({ 1, 2, 2, 1 }, DataType),
1163 TensorInfo({ 1, 1, 1, 4 }, DataType));
1164}
1165
Sadik Armagan1625efc2021-06-10 18:24:34 +01001166TEST_CASE("CreateSpaceToDepthWorkloadFloat32")
James Conroy60597842019-07-02 10:57:56 +01001167{
1168 RefCreateSpaceToDepthWorkloadTest<RefSpaceToDepthWorkload, armnn::DataType::Float32>();
1169}
1170
Sadik Armagan1625efc2021-06-10 18:24:34 +01001171TEST_CASE("CreateSpaceToDepthWorkloadFloat16")
Matthew Jackson9bff1442019-09-12 09:08:23 +01001172{
1173 RefCreateSpaceToDepthWorkloadTest<RefSpaceToDepthWorkload, armnn::DataType::Float16>();
1174}
1175
Sadik Armagan1625efc2021-06-10 18:24:34 +01001176TEST_CASE("CreateSpaceToDepthWorkloadQASymm8")
James Conroy60597842019-07-02 10:57:56 +01001177{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001178 RefCreateSpaceToDepthWorkloadTest<RefSpaceToDepthWorkload, armnn::DataType::QAsymmU8>();
James Conroy60597842019-07-02 10:57:56 +01001179}
1180
Sadik Armagan1625efc2021-06-10 18:24:34 +01001181TEST_CASE("CreateSpaceToDepthWorkloadQSymm16")
James Conroy60597842019-07-02 10:57:56 +01001182{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001183 RefCreateSpaceToDepthWorkloadTest<RefSpaceToDepthWorkload, armnn::DataType::QSymmS16>();
James Conroy60597842019-07-02 10:57:56 +01001184}
1185
Matthew Jacksond5166102019-07-31 14:06:28 +01001186template <armnn::DataType DataType>
Matthew Jackson81e601c2019-07-11 12:07:09 +01001187static void RefCreateStackWorkloadTest(const armnn::TensorShape& inputShape,
1188 const armnn::TensorShape& outputShape,
1189 unsigned int axis,
Matthew Jacksond5166102019-07-31 14:06:28 +01001190 unsigned int numInputs)
Matthew Jackson81e601c2019-07-11 12:07:09 +01001191{
1192 armnn::Graph graph;
1193 RefWorkloadFactory factory;
Matthew Jacksond5166102019-07-31 14:06:28 +01001194 auto workload = CreateStackWorkloadTest<RefStackWorkload, DataType>(factory,
1195 graph,
1196 inputShape,
1197 outputShape,
1198 axis,
1199 numInputs);
Matthew Jackson81e601c2019-07-11 12:07:09 +01001200
Matthew Jacksond5166102019-07-31 14:06:28 +01001201 // Check inputs and output are as expected
1202 StackQueueDescriptor queueDescriptor = workload->GetData();
1203 for (unsigned int i = 0; i < numInputs; ++i)
1204 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001205 auto inputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[i]);
Sadik Armagan1625efc2021-06-10 18:24:34 +01001206 CHECK((inputHandle->GetTensorInfo() == TensorInfo(inputShape, DataType)));
Matthew Jacksond5166102019-07-31 14:06:28 +01001207 }
Jan Eilersbb446e52020-04-02 13:56:54 +01001208 auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +01001209 CHECK((outputHandle->GetTensorInfo() == TensorInfo(outputShape, DataType)));
Matthew Jackson81e601c2019-07-11 12:07:09 +01001210}
1211
Sadik Armagan1625efc2021-06-10 18:24:34 +01001212TEST_CASE("CreateStackFloat32Workload")
Matthew Jackson81e601c2019-07-11 12:07:09 +01001213{
Matthew Jacksond5166102019-07-31 14:06:28 +01001214 RefCreateStackWorkloadTest<armnn::DataType::Float32>({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2);
Matthew Jackson81e601c2019-07-11 12:07:09 +01001215}
1216
Sadik Armagan1625efc2021-06-10 18:24:34 +01001217TEST_CASE("CreateStackUint8Workload")
Matthew Jackson81e601c2019-07-11 12:07:09 +01001218{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001219 RefCreateStackWorkloadTest<armnn::DataType::QAsymmU8>({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2);
Matthew Jackson81e601c2019-07-11 12:07:09 +01001220}
1221
Sadik Armagan1625efc2021-06-10 18:24:34 +01001222TEST_CASE("CreateStackUint16Workload")
Matthew Jackson81e601c2019-07-11 12:07:09 +01001223{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001224 RefCreateStackWorkloadTest<armnn::DataType::QSymmS16>({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2);
Matthew Jackson81e601c2019-07-11 12:07:09 +01001225}
1226
James Conroy4f1f8992020-04-29 20:01:10 +01001227template <typename QLstmWorkloadType>
1228static void RefCreateQLstmWorkloadTest()
1229{
1230 Graph graph;
1231 RefWorkloadFactory factory;
1232
1233 auto workload = CreateQLstmWorkloadTest<QLstmWorkloadType>(factory, graph);
1234
1235 armnn::TensorInfo inputInfo({2 , 4}, armnn::DataType::QAsymmS8, 0.0078125f, 0);
1236
1237 armnn::TensorInfo cellStateInfo({2 , 4}, armnn::DataType::QSymmS16, 3.05176e-05f, 0);
1238
1239 armnn::TensorInfo outputInfo({2 , 4}, armnn::DataType::QAsymmS8, 0.007f, 0);
1240
1241 QLstmQueueDescriptor queueDescriptor = workload->GetData();
Jan Eilersaaf9a8f2020-07-01 16:35:35 +01001242 auto inputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
1243 auto cellStateOutHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[1]);
1244 auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[2]);
James Conroy4f1f8992020-04-29 20:01:10 +01001245
Sadik Armagan1625efc2021-06-10 18:24:34 +01001246 CHECK((inputHandle->GetTensorInfo() == inputInfo));
1247 CHECK((cellStateOutHandle->GetTensorInfo() == cellStateInfo));
1248 CHECK((outputHandle->GetTensorInfo() == outputInfo));
James Conroy4f1f8992020-04-29 20:01:10 +01001249}
1250
Sadik Armagan1625efc2021-06-10 18:24:34 +01001251TEST_CASE("CreateQLstmWorkload")
James Conroy4f1f8992020-04-29 20:01:10 +01001252{
1253 RefCreateQLstmWorkloadTest<RefQLstmWorkload>();
1254}
1255
Sadik Armagan1625efc2021-06-10 18:24:34 +01001256}