blob: 3bba0b73938996926e71d635f78717ce2de4c54a [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
Mike Kelly52e90bf2023-03-15 15:06:23 +00002// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
arovir0143095f32018-10-09 18:04:24 +01005
Sadik Armagana097d2a2021-11-24 15:47:28 +00006#include <CreateWorkload.hpp>
arovir0143095f32018-10-09 18:04:24 +01007
Jan Eilersbb446e52020-04-02 13:56:54 +01008#include <armnn/utility/PolymorphicDowncast.hpp>
Matthew Bentham4cefc412019-06-18 16:14:34 +01009#include <reference/RefTensorHandle.hpp>
Teresa Charlin788e2a62022-01-17 21:19:52 +000010#include <reference/RefTensorHandleFactory.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000011#include <reference/RefWorkloadFactory.hpp>
12#include <reference/workloads/RefWorkloads.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Sadik Armagan1625efc2021-06-10 18:24:34 +010014#include <doctest/doctest.h>
15
telsoa014fcda012018-03-09 14:13:49 +000016namespace
17{
18
19template<typename Workload>
20void CheckInputOutput(std::unique_ptr<Workload> workload, const TensorInfo& inputInfo, const TensorInfo& outputInfo)
21{
22 auto queueDescriptor = workload->GetData();
Jan Eilersbb446e52020-04-02 13:56:54 +010023 auto inputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
24 auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +010025 CHECK((inputHandle->GetTensorInfo() == inputInfo));
26 CHECK((outputHandle->GetTensorInfo() == outputInfo));
telsoa014fcda012018-03-09 14:13:49 +000027}
28
29template <typename Workload>
30void CheckInputsOutput(std::unique_ptr<Workload> workload,
31 const TensorInfo& inputInfo0,
32 const TensorInfo& inputInfo1,
33 const TensorInfo& outputInfo)
34{
35 auto queueDescriptor = workload->GetData();
Jan Eilersbb446e52020-04-02 13:56:54 +010036 auto inputHandle0 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
37 auto inputHandle1 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[1]);
38 auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +010039 CHECK((inputHandle0->GetTensorInfo() == inputInfo0));
40 CHECK((inputHandle1->GetTensorInfo() == inputInfo1));
41 CHECK((outputHandle->GetTensorInfo() == outputInfo));
telsoa014fcda012018-03-09 14:13:49 +000042}
Matthew Bentham7c1603a2019-06-21 17:22:23 +010043
44armnn::RefWorkloadFactory GetFactory()
45{
46 std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
47 return RefWorkloadFactory(memoryManager);
48}
49
telsoa014fcda012018-03-09 14:13:49 +000050}
51
Sadik Armagan1625efc2021-06-10 18:24:34 +010052TEST_SUITE("CreateWorkloadRef")
53{
telsoa01c577f2c2018-08-31 09:22:23 +010054template <typename ActivationWorkloadType, armnn::DataType DataType>
telsoa014fcda012018-03-09 14:13:49 +000055static void RefCreateActivationWorkloadTest()
56{
57 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +010058 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +010059 auto workload = CreateActivationWorkloadTest<ActivationWorkloadType, DataType>(factory, graph);
telsoa014fcda012018-03-09 14:13:49 +000060
telsoa01c577f2c2018-08-31 09:22:23 +010061 // Checks that outputs are as we expect them (see definition of CreateActivationWorkloadTest).
telsoa014fcda012018-03-09 14:13:49 +000062 CheckInputOutput(std::move(workload),
telsoa01c577f2c2018-08-31 09:22:23 +010063 TensorInfo({ 1, 1 }, DataType),
64 TensorInfo({ 1, 1 }, DataType));
telsoa014fcda012018-03-09 14:13:49 +000065}
66
Sadik Armagan1625efc2021-06-10 18:24:34 +010067TEST_CASE("CreateActivationFloat32Workload")
telsoa014fcda012018-03-09 14:13:49 +000068{
Nattapat Chaimanowongae2c5f02019-04-24 16:19:57 +010069 RefCreateActivationWorkloadTest<RefActivationWorkload, armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +000070}
71
Sadik Armagan1625efc2021-06-10 18:24:34 +010072TEST_CASE("CreateActivationUint8Workload")
telsoa014fcda012018-03-09 14:13:49 +000073{
Derek Lambertif90c56d2020-01-10 17:14:08 +000074 RefCreateActivationWorkloadTest<RefActivationWorkload, armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +000075}
76
David Beckbc392452018-09-10 14:47:28 +010077template <typename WorkloadType,
78 typename DescriptorType,
79 typename LayerType,
80 armnn::DataType DataType>
Éanna Ó Catháind57415d2018-11-28 16:24:38 +000081static void RefCreateElementwiseWorkloadTest()
telsoa014fcda012018-03-09 14:13:49 +000082{
83 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +010084 RefWorkloadFactory factory = GetFactory();
Éanna Ó Catháind57415d2018-11-28 16:24:38 +000085 auto workload = CreateElementwiseWorkloadTest<WorkloadType, DescriptorType, LayerType, DataType>(
86 factory, graph);
telsoa014fcda012018-03-09 14:13:49 +000087
telsoa014fcda012018-03-09 14:13:49 +000088 CheckInputsOutput(std::move(workload),
telsoa01c577f2c2018-08-31 09:22:23 +010089 TensorInfo({ 2, 3 }, DataType),
90 TensorInfo({ 2, 3 }, DataType),
91 TensorInfo({ 2, 3 }, DataType));
telsoa014fcda012018-03-09 14:13:49 +000092}
93
Sadik Armagan1625efc2021-06-10 18:24:34 +010094TEST_CASE("CreateSubtractionWorkloadWithBlobTest")
Keith Davisdf04d232020-10-23 17:20:05 +010095{
96 Graph graph;
97 RefWorkloadFactory factory = GetFactory();
98 armnn::DataType DataType = armnn::DataType::Float32;
99
100 auto workload = CreateSubtractionWithBlobWorkloadTest<RefSubtractionWorkload<>,
101 SubtractionQueueDescriptor,
102 armnn::DataType::Float32>
103 (factory, graph);
104
105 CheckInputsOutput(std::move(workload),
106 TensorInfo({ 2, 3 }, DataType),
107 TensorInfo({ 2, 3 }, DataType),
108 TensorInfo({ 2, 3 }, DataType));
109}
110
Sadik Armagan1625efc2021-06-10 18:24:34 +0100111TEST_CASE("CreateAdditionWorkloadWithBlobTest")
Keith Davisdf04d232020-10-23 17:20:05 +0100112{
113 Graph graph;
114 RefWorkloadFactory factory = GetFactory();
115 armnn::DataType DataType = armnn::DataType::Float32;
116
117 auto workload = CreateAdditionWithBlobWorkloadTest<RefAdditionWorkload<>,
118 AdditionQueueDescriptor,
119 armnn::DataType::Float32>(factory, graph);
120
121 CheckInputsOutput(std::move(workload),
122 TensorInfo({ 2, 3 }, DataType),
123 TensorInfo({ 2, 3 }, DataType),
124 TensorInfo({ 2, 3 }, DataType));
125}
126
Sadik Armagan1625efc2021-06-10 18:24:34 +0100127TEST_CASE("CreateMultiplicationWorkloadWithBlobTest")
Keith Davisdf04d232020-10-23 17:20:05 +0100128{
129 Graph graph;
130 RefWorkloadFactory factory = GetFactory();
131 armnn::DataType DataType = armnn::DataType::Float32;
132
133 auto workload = CreateMultiplicationWithBlobWorkloadTest<RefMultiplicationWorkload<>,
134 MultiplicationQueueDescriptor,
135 armnn::DataType::Float32>(factory, graph);
136
137 CheckInputsOutput(std::move(workload),
138 TensorInfo({2, 3}, DataType),
139 TensorInfo({2, 3}, DataType),
140 TensorInfo({2, 3}, DataType));
141}
142
Sadik Armagan1625efc2021-06-10 18:24:34 +0100143TEST_CASE("CreateAdditionFloatWorkload")
telsoa014fcda012018-03-09 14:13:49 +0000144{
Mike Kelly52e90bf2023-03-15 15:06:23 +0000145 ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williamscbd2c232020-06-22 15:58:32 +0100146 RefCreateElementwiseWorkloadTest<RefAdditionWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000147 AdditionQueueDescriptor,
148 AdditionLayer,
149 armnn::DataType::Float32>();
Mike Kelly52e90bf2023-03-15 15:06:23 +0000150 ARMNN_NO_DEPRECATE_WARN_END
telsoa014fcda012018-03-09 14:13:49 +0000151}
152
Sadik Armagan1625efc2021-06-10 18:24:34 +0100153TEST_CASE("CreateAdditionUint8Workload")
telsoa014fcda012018-03-09 14:13:49 +0000154{
Mike Kelly52e90bf2023-03-15 15:06:23 +0000155 ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williamscbd2c232020-06-22 15:58:32 +0100156 RefCreateElementwiseWorkloadTest<RefAdditionWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000157 AdditionQueueDescriptor,
158 AdditionLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000159 armnn::DataType::QAsymmU8>();
Mike Kelly52e90bf2023-03-15 15:06:23 +0000160 ARMNN_NO_DEPRECATE_WARN_END
David Beckbc392452018-09-10 14:47:28 +0100161}
162
Sadik Armagan1625efc2021-06-10 18:24:34 +0100163TEST_CASE("CreateAdditionInt16Workload")
Sadik Armagan2999a022019-04-09 14:20:12 +0100164{
Mike Kelly52e90bf2023-03-15 15:06:23 +0000165 ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williamscbd2c232020-06-22 15:58:32 +0100166 RefCreateElementwiseWorkloadTest<RefAdditionWorkload<>,
Sadik Armagan2999a022019-04-09 14:20:12 +0100167 AdditionQueueDescriptor,
168 AdditionLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000169 armnn::DataType::QSymmS16>();
Mike Kelly52e90bf2023-03-15 15:06:23 +0000170 ARMNN_NO_DEPRECATE_WARN_END
Sadik Armagan2999a022019-04-09 14:20:12 +0100171}
172
Sadik Armagan1625efc2021-06-10 18:24:34 +0100173TEST_CASE("CreateAdditionInt32Workload")
Finn Williamscbd2c232020-06-22 15:58:32 +0100174{
Mike Kelly52e90bf2023-03-15 15:06:23 +0000175 ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williamscbd2c232020-06-22 15:58:32 +0100176 RefCreateElementwiseWorkloadTest<RefAdditionWorkload<int32_t>,
177 AdditionQueueDescriptor,
178 AdditionLayer,
179 armnn::DataType::Signed32>();
Mike Kelly52e90bf2023-03-15 15:06:23 +0000180 ARMNN_NO_DEPRECATE_WARN_END
Finn Williamscbd2c232020-06-22 15:58:32 +0100181}
182
Sadik Armagan1625efc2021-06-10 18:24:34 +0100183TEST_CASE("CreateSubtractionFloat32Workload")
David Beckbc392452018-09-10 14:47:28 +0100184{
Mike Kelly52e90bf2023-03-15 15:06:23 +0000185 ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williamscbd2c232020-06-22 15:58:32 +0100186 RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000187 SubtractionQueueDescriptor,
188 SubtractionLayer,
189 armnn::DataType::Float32>();
Mike Kelly52e90bf2023-03-15 15:06:23 +0000190 ARMNN_NO_DEPRECATE_WARN_END
David Beckbc392452018-09-10 14:47:28 +0100191}
192
Sadik Armagan1625efc2021-06-10 18:24:34 +0100193TEST_CASE("CreateSubtractionFloat16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100194{
Mike Kelly52e90bf2023-03-15 15:06:23 +0000195 ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williamscbd2c232020-06-22 15:58:32 +0100196 RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100197 SubtractionQueueDescriptor,
198 SubtractionLayer,
199 armnn::DataType::Float16>();
Mike Kelly52e90bf2023-03-15 15:06:23 +0000200 ARMNN_NO_DEPRECATE_WARN_END
Matthew Jackson9bff1442019-09-12 09:08:23 +0100201}
202
Sadik Armagan1625efc2021-06-10 18:24:34 +0100203TEST_CASE("CreateSubtractionUint8Workload")
David Beckbc392452018-09-10 14:47:28 +0100204{
Mike Kelly52e90bf2023-03-15 15:06:23 +0000205 ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williamscbd2c232020-06-22 15:58:32 +0100206 RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000207 SubtractionQueueDescriptor,
208 SubtractionLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000209 armnn::DataType::QAsymmU8>();
Mike Kelly52e90bf2023-03-15 15:06:23 +0000210 ARMNN_NO_DEPRECATE_WARN_END
David Beckbc392452018-09-10 14:47:28 +0100211}
212
Sadik Armagan1625efc2021-06-10 18:24:34 +0100213TEST_CASE("CreateSubtractionInt16Workload")
Sadik Armagan2999a022019-04-09 14:20:12 +0100214{
Mike Kelly52e90bf2023-03-15 15:06:23 +0000215 ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williamscbd2c232020-06-22 15:58:32 +0100216 RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
Sadik Armagan2999a022019-04-09 14:20:12 +0100217 SubtractionQueueDescriptor,
218 SubtractionLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000219 armnn::DataType::QSymmS16>();
Mike Kelly52e90bf2023-03-15 15:06:23 +0000220 ARMNN_NO_DEPRECATE_WARN_END
Sadik Armagan2999a022019-04-09 14:20:12 +0100221}
222
Sadik Armagan1625efc2021-06-10 18:24:34 +0100223TEST_CASE("CreateSubtractionInt32Workload")
Finn Williamscbd2c232020-06-22 15:58:32 +0100224{
Mike Kelly52e90bf2023-03-15 15:06:23 +0000225 ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williamscbd2c232020-06-22 15:58:32 +0100226 RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<int32_t>,
227 SubtractionQueueDescriptor,
228 SubtractionLayer,
229 armnn::DataType::Signed32>();
Mike Kelly52e90bf2023-03-15 15:06:23 +0000230 ARMNN_NO_DEPRECATE_WARN_END
Finn Williamscbd2c232020-06-22 15:58:32 +0100231}
232
Sadik Armagan1625efc2021-06-10 18:24:34 +0100233TEST_CASE("CreateMultiplicationFloatWorkload")
David Beckbc392452018-09-10 14:47:28 +0100234{
Mike Kelly52e90bf2023-03-15 15:06:23 +0000235 ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williamscbd2c232020-06-22 15:58:32 +0100236 RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000237 MultiplicationQueueDescriptor,
238 MultiplicationLayer,
239 armnn::DataType::Float32>();
Mike Kelly52e90bf2023-03-15 15:06:23 +0000240 ARMNN_NO_DEPRECATE_WARN_END
David Beckbc392452018-09-10 14:47:28 +0100241}
242
Sadik Armagan1625efc2021-06-10 18:24:34 +0100243TEST_CASE("CreateMultiplicationUint8Workload")
David Beckbc392452018-09-10 14:47:28 +0100244{
Mike Kelly52e90bf2023-03-15 15:06:23 +0000245 ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williamscbd2c232020-06-22 15:58:32 +0100246 RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000247 MultiplicationQueueDescriptor,
248 MultiplicationLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000249 armnn::DataType::QAsymmU8>();
Mike Kelly52e90bf2023-03-15 15:06:23 +0000250 ARMNN_NO_DEPRECATE_WARN_END
David Beckbc392452018-09-10 14:47:28 +0100251}
252
Sadik Armagan1625efc2021-06-10 18:24:34 +0100253TEST_CASE("CreateMultiplicationInt16Workload")
Sadik Armagan2999a022019-04-09 14:20:12 +0100254{
Mike Kelly52e90bf2023-03-15 15:06:23 +0000255 ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williamscbd2c232020-06-22 15:58:32 +0100256 RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<>,
Sadik Armagan2999a022019-04-09 14:20:12 +0100257 MultiplicationQueueDescriptor,
258 MultiplicationLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000259 armnn::DataType::QSymmS16>();
Mike Kelly52e90bf2023-03-15 15:06:23 +0000260 ARMNN_NO_DEPRECATE_WARN_END
Sadik Armagan2999a022019-04-09 14:20:12 +0100261}
262
Sadik Armagan1625efc2021-06-10 18:24:34 +0100263TEST_CASE("CreateMultiplicationInt32Workload")
Finn Williamscbd2c232020-06-22 15:58:32 +0100264{
Mike Kelly52e90bf2023-03-15 15:06:23 +0000265 ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williamscbd2c232020-06-22 15:58:32 +0100266 RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<int32_t>,
267 MultiplicationQueueDescriptor,
268 MultiplicationLayer,
269 armnn::DataType::Signed32>();
Mike Kelly52e90bf2023-03-15 15:06:23 +0000270 ARMNN_NO_DEPRECATE_WARN_END
Finn Williamscbd2c232020-06-22 15:58:32 +0100271}
272
Sadik Armagan1625efc2021-06-10 18:24:34 +0100273TEST_CASE("CreateDivisionFloat32Workload")
David Beckbc392452018-09-10 14:47:28 +0100274{
Mike Kelly52e90bf2023-03-15 15:06:23 +0000275 ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williamscbd2c232020-06-22 15:58:32 +0100276 RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000277 DivisionQueueDescriptor,
278 DivisionLayer,
279 armnn::DataType::Float32>();
Mike Kelly52e90bf2023-03-15 15:06:23 +0000280 ARMNN_NO_DEPRECATE_WARN_END
David Beckbc392452018-09-10 14:47:28 +0100281}
282
Sadik Armagan1625efc2021-06-10 18:24:34 +0100283TEST_CASE("CreateDivisionFloat16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100284{
Mike Kelly52e90bf2023-03-15 15:06:23 +0000285 ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williamscbd2c232020-06-22 15:58:32 +0100286 RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100287 DivisionQueueDescriptor,
288 DivisionLayer,
289 armnn::DataType::Float16>();
Mike Kelly52e90bf2023-03-15 15:06:23 +0000290 ARMNN_NO_DEPRECATE_WARN_END
Matthew Jackson9bff1442019-09-12 09:08:23 +0100291}
292
Sadik Armagan1625efc2021-06-10 18:24:34 +0100293TEST_CASE("CreateDivisionUint8Workload")
David Beckbc392452018-09-10 14:47:28 +0100294{
Mike Kelly52e90bf2023-03-15 15:06:23 +0000295 ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williamscbd2c232020-06-22 15:58:32 +0100296 RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
Éanna Ó Catháind57415d2018-11-28 16:24:38 +0000297 DivisionQueueDescriptor,
298 DivisionLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000299 armnn::DataType::QAsymmU8>();
Mike Kelly52e90bf2023-03-15 15:06:23 +0000300 ARMNN_NO_DEPRECATE_WARN_END
telsoa014fcda012018-03-09 14:13:49 +0000301}
302
Sadik Armagan1625efc2021-06-10 18:24:34 +0100303TEST_CASE("CreateDivisionInt16Workload")
Sadik Armagan2999a022019-04-09 14:20:12 +0100304{
Mike Kelly52e90bf2023-03-15 15:06:23 +0000305 ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williamscbd2c232020-06-22 15:58:32 +0100306 RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
Sadik Armagan2999a022019-04-09 14:20:12 +0100307 DivisionQueueDescriptor,
308 DivisionLayer,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000309 armnn::DataType::QSymmS16>();
Mike Kelly52e90bf2023-03-15 15:06:23 +0000310 ARMNN_NO_DEPRECATE_WARN_END
Sadik Armagan2999a022019-04-09 14:20:12 +0100311}
312
Sadik Armagan1625efc2021-06-10 18:24:34 +0100313TEST_CASE("CreateDivisionInt32Workload")
Finn Williamscbd2c232020-06-22 15:58:32 +0100314{
Mike Kelly52e90bf2023-03-15 15:06:23 +0000315 ARMNN_NO_DEPRECATE_WARN_BEGIN
Finn Williamscbd2c232020-06-22 15:58:32 +0100316 RefCreateElementwiseWorkloadTest<RefDivisionWorkload<int32_t>,
317 DivisionQueueDescriptor,
318 DivisionLayer,
319 armnn::DataType::Signed32>();
Mike Kelly52e90bf2023-03-15 15:06:23 +0000320 ARMNN_NO_DEPRECATE_WARN_END
Finn Williamscbd2c232020-06-22 15:58:32 +0100321}
322
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100323template <typename BatchNormalizationWorkloadType, armnn::DataType DataType>
324static void RefCreateBatchNormalizationWorkloadTest(DataLayout dataLayout)
telsoa014fcda012018-03-09 14:13:49 +0000325{
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100326 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100327 RefWorkloadFactory factory = GetFactory();
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100328 auto workload = CreateBatchNormalizationWorkloadTest<BatchNormalizationWorkloadType, DataType>(factory,
329 graph,
330 dataLayout);
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100331
332 TensorShape inputShape;
333 TensorShape outputShape;
334
335 switch (dataLayout)
336 {
337 case DataLayout::NHWC:
Nikhil Rajd1340932018-10-18 14:27:50 +0100338 inputShape = { 2, 4, 4, 3 };
339 outputShape = { 2, 4, 4, 3 };
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100340 break;
341 case DataLayout::NCHW:
342 default:
Nikhil Rajd1340932018-10-18 14:27:50 +0100343 inputShape = { 2, 3, 4, 4 };
344 outputShape = { 2, 3, 4, 4 };
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100345 break;
346 }
telsoa014fcda012018-03-09 14:13:49 +0000347
telsoa01c577f2c2018-08-31 09:22:23 +0100348 // Checks that outputs and inputs are as we expect them (see definition of CreateBatchNormalizationWorkloadTest).
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100349 CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
350}
351
Sadik Armagan1625efc2021-06-10 18:24:34 +0100352TEST_CASE("CreateBatchNormalizationWithBlobFloat32Workload")
Keith Davisdf04d232020-10-23 17:20:05 +0100353{
354 Graph graph;
355 RefWorkloadFactory factory = GetFactory();
356 auto dataType = armnn::DataType::Float32;
357 auto workload = CreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload,
358 armnn::DataType::Float32>(factory, graph, DataLayout::NHWC);
359
360 TensorShape inputShape;
361 TensorShape outputShape;
362
363 inputShape = { 2, 4, 4, 3 };
364 outputShape = { 2, 4, 4, 3 };
365
366 // Checks that outputs and inputs are as we expect them (see definition of CreateBatchNormalizationWorkloadTest).
367 CheckInputOutput(std::move(workload), TensorInfo(inputShape, dataType), TensorInfo(outputShape, dataType));
368}
369
Sadik Armagan1625efc2021-06-10 18:24:34 +0100370TEST_CASE("CreateBatchNormalizationFloat32Workload")
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100371{
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100372 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload,armnn::DataType::Float32>
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100373 (DataLayout::NCHW);
374}
375
Sadik Armagan1625efc2021-06-10 18:24:34 +0100376TEST_CASE("CreateBatchNormalizationFloat32WorkloadNhwc")
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100377{
Matteo Martincigh3122bd52019-06-03 16:54:25 +0100378 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::Float32>
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100379 (DataLayout::NHWC);
380}
381
Sadik Armagan1625efc2021-06-10 18:24:34 +0100382TEST_CASE("CreateBatchNormalizationFloat16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100383{
384 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload,armnn::DataType::Float16>
385 (DataLayout::NCHW);
386}
387
Sadik Armagan1625efc2021-06-10 18:24:34 +0100388TEST_CASE("CreateBatchNormalizationFloat16WorkloadNhwc")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100389{
390 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::Float16>
391 (DataLayout::NHWC);
392}
393
Sadik Armagan1625efc2021-06-10 18:24:34 +0100394TEST_CASE("CreateBatchNormalizationUint8Workload")
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100395{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000396 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QAsymmU8>
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100397 (DataLayout::NCHW);
398}
399
Sadik Armagan1625efc2021-06-10 18:24:34 +0100400TEST_CASE("CreateBatchNormalizationUint8WorkloadNhwc")
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100401{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000402 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QAsymmU8>
Matteo Martincigh3dc43032018-10-18 10:55:19 +0100403 (DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000404}
405
Sadik Armagan1625efc2021-06-10 18:24:34 +0100406TEST_CASE("CreateBatchNormalizationInt16Workload")
Matteo Martincighf5507132019-06-04 10:59:47 +0100407{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000408 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QSymmS16>
Matteo Martincighf5507132019-06-04 10:59:47 +0100409 (DataLayout::NCHW);
410}
411
Sadik Armagan1625efc2021-06-10 18:24:34 +0100412TEST_CASE("CreateBatchNormalizationInt16WorkloadNhwc")
Matteo Martincighf5507132019-06-04 10:59:47 +0100413{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000414 RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationWorkload, armnn::DataType::QSymmS16>
Matteo Martincighf5507132019-06-04 10:59:47 +0100415 (DataLayout::NHWC);
416}
417
Sadik Armagan1625efc2021-06-10 18:24:34 +0100418TEST_CASE("CreateConvertFp16ToFp32Float32Workload")
telsoa01c577f2c2018-08-31 09:22:23 +0100419{
420 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100421 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +0100422 auto workload = CreateConvertFp16ToFp32WorkloadTest<RefConvertFp16ToFp32Workload>(factory, graph);
423
424 // Checks that outputs and inputs are as we expect them
425 CheckInputOutput(
426 std::move(workload), TensorInfo({1, 3, 2, 3}, DataType::Float16), TensorInfo({1, 3, 2, 3}, DataType::Float32));
427}
428
Sadik Armagan1625efc2021-06-10 18:24:34 +0100429TEST_CASE("CreateConvertFp32ToFp16Float16Workload")
telsoa01c577f2c2018-08-31 09:22:23 +0100430{
431 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100432 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +0100433 auto workload = CreateConvertFp32ToFp16WorkloadTest<RefConvertFp32ToFp16Workload>(factory, graph);
434
435 // Checks that outputs and inputs are as we expect them
436 CheckInputOutput(
437 std::move(workload), TensorInfo({1, 3, 2, 3}, DataType::Float32), TensorInfo({1, 3, 2, 3}, DataType::Float16));
438}
439
Nikhil Raje4dfd6e2018-10-18 10:11:04 +0100440static void RefCreateConvolution2dWorkloadTest(DataLayout dataLayout = DataLayout::NCHW)
telsoa014fcda012018-03-09 14:13:49 +0000441{
Nikhil Raje4dfd6e2018-10-18 10:11:04 +0100442 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100443 RefWorkloadFactory factory = GetFactory();
Mike Kelly9b398322019-05-22 17:21:49 +0100444 auto workload = CreateConvolution2dWorkloadTest<RefConvolution2dWorkload, DataType::Float32>
Nikhil Raje4dfd6e2018-10-18 10:11:04 +0100445 (factory, graph, dataLayout);
446
Mike Kellydb482882019-06-14 12:35:24 +0100447 TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({2, 3, 8, 16})
448 : std::initializer_list<unsigned int>({2, 8, 16, 3});
449 TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({2, 2, 2, 10})
450 : std::initializer_list<unsigned int>({2, 2, 10, 2});
telsoa014fcda012018-03-09 14:13:49 +0000451
telsoa01c577f2c2018-08-31 09:22:23 +0100452 // Checks that outputs and inputs are as we expect them (see definition of CreateConvolution2dWorkloadTest).
telsoa014fcda012018-03-09 14:13:49 +0000453 CheckInputOutput(std::move(workload),
Nikhil Raje4dfd6e2018-10-18 10:11:04 +0100454 TensorInfo(inputShape, DataType::Float32),
455 TensorInfo(outputShape, DataType::Float32));
456}
457
Sadik Armagan1625efc2021-06-10 18:24:34 +0100458TEST_CASE("CreateConvolution2dFloatNchwWorkload")
Nikhil Raje4dfd6e2018-10-18 10:11:04 +0100459{
460 RefCreateConvolution2dWorkloadTest(DataLayout::NCHW);
461}
462
Sadik Armagan1625efc2021-06-10 18:24:34 +0100463TEST_CASE("CreateConvolution2dFloatNhwcWorkload")
Nikhil Raje4dfd6e2018-10-18 10:11:04 +0100464{
465 RefCreateConvolution2dWorkloadTest(DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000466}
467
Sadik Armagan1625efc2021-06-10 18:24:34 +0100468TEST_CASE("CreateConvolution2dWithBlobWorkload")
Keith Davisdf04d232020-10-23 17:20:05 +0100469{
470 DataLayout dataLayout = DataLayout::NHWC;
471 Graph graph;
472 RefWorkloadFactory factory = GetFactory();
473 auto workload = CreateConvolution2dFusedActivationWithBlobWorkloadTest<RefConvolution2dWorkload, DataType::Float32>
474 (factory, graph, dataLayout);
475
476 TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({2, 3, 8, 16})
477 : std::initializer_list<unsigned int>({2, 8, 16, 3});
478 TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({2, 2, 2, 10})
479 : std::initializer_list<unsigned int>({2, 2, 10, 2});
480
481 // Checks that outputs and inputs are as we expect them (see definition of CreateConvolution2dWorkloadTest).
482 CheckInputOutput(std::move(workload),
483 TensorInfo(inputShape, DataType::Float32),
484 TensorInfo(outputShape, DataType::Float32));
485}
486
Ruomei Yan495852f2019-05-23 11:37:33 +0100487static void RefCreateDepthwiseConvolutionWorkloadTest(DataLayout dataLayout)
488{
489 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100490 RefWorkloadFactory factory = GetFactory();
Ruomei Yan495852f2019-05-23 11:37:33 +0100491 auto workload = CreateDepthwiseConvolution2dWorkloadTest<RefDepthwiseConvolution2dWorkload, DataType::Float32>
492 (factory, graph, dataLayout);
493
Mike Kellydb482882019-06-14 12:35:24 +0100494 TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({ 2, 2, 5, 5 })
495 : std::initializer_list<unsigned int>({ 2, 5, 5, 2 });
496 TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? std::initializer_list<unsigned int>({ 2, 2, 5, 5 })
497 : std::initializer_list<unsigned int>({ 2, 5, 5, 2 });
498
Ruomei Yan495852f2019-05-23 11:37:33 +0100499 // Checks that inputs/outputs are as we expect them (see definition of CreateDepthwiseConvolution2dWorkloadTest).
500 CheckInputOutput(std::move(workload),
501 TensorInfo(inputShape, DataType::Float32),
502 TensorInfo(outputShape, DataType::Float32));
503}
504
Sadik Armagan1625efc2021-06-10 18:24:34 +0100505TEST_CASE("CreateDepthwiseConvolutionFloat32NhwcWorkload")
Ruomei Yan495852f2019-05-23 11:37:33 +0100506{
507 RefCreateDepthwiseConvolutionWorkloadTest(DataLayout::NHWC);
508}
509
Sadik Armagan1625efc2021-06-10 18:24:34 +0100510TEST_CASE("RefCreateFullyConnectedWithBlobWorkloadTest")
Keith Davisdf04d232020-10-23 17:20:05 +0100511{
512 Graph graph;
513 RefWorkloadFactory factory = GetFactory();
514 auto workload = CreateFullyConnectedWithBlobWorkloadTest<RefFullyConnectedWorkload,
515 armnn::DataType::Float32>(factory, graph);
516
517 // Checks that outputs and inputs are as we expect them (see definition of CreateFullyConnectedWorkloadTest).
518 float inputsQScale = 0.0f;
519 float outputQScale = 0.0f;
520 CheckInputOutput(std::move(workload),
521 TensorInfo({ 3, 1, 4, 5 }, armnn::DataType::Float32, inputsQScale),
522 TensorInfo({ 3, 7 }, armnn::DataType::Float32, outputQScale));
523}
524
Matthew Sloyan81beae32021-07-13 19:46:11 +0100525TEST_CASE("CreateFullyConnectedWorkloadWeightsBiasesAsInputsFloat32")
526{
527 Graph graph;
528 RefWorkloadFactory factory = GetFactory();
529
530 auto workload =
531 CreateFullyConnectedWorkloadWeightsBiasesAsInputsTest<RefFullyConnectedWorkload,
532 armnn::DataType::Float32>(factory, graph);
533
534 // Checks that outputs and inputs are as we expect them (see definition of CreateFullyConnectedWorkloadTest).
535 float inputsQScale = 0.0f;
536 float outputQScale = 0.0f;
537 CheckInputsOutput(std::move(workload),
538 TensorInfo({ 3, 1, 4, 5 }, armnn::DataType::Float32, inputsQScale),
539 TensorInfo({ 7, 20 }, armnn::DataType::Float32, inputsQScale),
540 TensorInfo({ 3, 7 }, armnn::DataType::Float32, outputQScale));
541}
542
telsoa01c577f2c2018-08-31 09:22:23 +0100543template <typename FullyConnectedWorkloadType, armnn::DataType DataType>
telsoa014fcda012018-03-09 14:13:49 +0000544static void RefCreateFullyConnectedWorkloadTest()
545{
546 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100547 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +0100548 auto workload = CreateFullyConnectedWorkloadTest<FullyConnectedWorkloadType, DataType>(factory, graph);
telsoa014fcda012018-03-09 14:13:49 +0000549
telsoa01c577f2c2018-08-31 09:22:23 +0100550 // Checks that outputs and inputs are as we expect them (see definition of CreateFullyConnectedWorkloadTest).
Derek Lambertif90c56d2020-01-10 17:14:08 +0000551 float inputsQScale = DataType == armnn::DataType::QAsymmU8 ? 1.0f : 0.0;
552 float outputQScale = DataType == armnn::DataType::QAsymmU8 ? 2.0f : 0.0;
telsoa014fcda012018-03-09 14:13:49 +0000553 CheckInputOutput(std::move(workload),
telsoa01c577f2c2018-08-31 09:22:23 +0100554 TensorInfo({ 3, 1, 4, 5 }, DataType, inputsQScale),
555 TensorInfo({ 3, 7 }, DataType, outputQScale));
telsoa014fcda012018-03-09 14:13:49 +0000556}
557
Sadik Armagan1625efc2021-06-10 18:24:34 +0100558TEST_CASE("CreateFullyConnectedWorkloadFloat32")
telsoa014fcda012018-03-09 14:13:49 +0000559{
Francis Murtagh43aec582019-05-27 12:14:10 +0100560 RefCreateFullyConnectedWorkloadTest<RefFullyConnectedWorkload, armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000561}
562
Sadik Armagan1625efc2021-06-10 18:24:34 +0100563TEST_CASE("CreateFullyConnectedWorkloadQuantisedAsymm8")
telsoa014fcda012018-03-09 14:13:49 +0000564{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000565 RefCreateFullyConnectedWorkloadTest<RefFullyConnectedWorkload, armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000566}
567
Sadik Armagan1625efc2021-06-10 18:24:34 +0100568TEST_CASE("CreateFullyConnectedWorkloadQuantisedSymm16")
Francis Murtagh46c09d02019-05-28 08:15:28 +0100569{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000570 RefCreateFullyConnectedWorkloadTest<RefFullyConnectedWorkload, armnn::DataType::QSymmS16>();
Francis Murtagh46c09d02019-05-28 08:15:28 +0100571}
572
narpra0155a97bc2018-10-02 14:35:53 +0100573template <typename NormalizationWorkloadType, armnn::DataType DataType>
Matteo Martincigha160b242018-10-18 10:33:23 +0100574static void RefCreateNormalizationWorkloadTest(DataLayout dataLayout)
telsoa014fcda012018-03-09 14:13:49 +0000575{
narpra0155a97bc2018-10-02 14:35:53 +0100576 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100577 RefWorkloadFactory factory = GetFactory();
Matteo Martincigha160b242018-10-18 10:33:23 +0100578 auto workload = CreateNormalizationWorkloadTest<NormalizationWorkloadType, DataType>(factory, graph, dataLayout);
579
580 TensorShape inputShape;
581 TensorShape outputShape;
582
583 switch (dataLayout)
584 {
585 case DataLayout::NHWC:
586 inputShape = { 3, 1, 5, 5 };
587 outputShape = { 3, 1, 5, 5 };
588 break;
589 case DataLayout::NCHW:
590 default:
591 inputShape = { 3, 5, 5, 1 };
592 outputShape = { 3, 5, 5, 1 };
593 break;
594 }
telsoa014fcda012018-03-09 14:13:49 +0000595
telsoa01c577f2c2018-08-31 09:22:23 +0100596 // Checks that outputs and inputs are as we expect them (see definition of CreateNormalizationWorkloadTest).
Matteo Martincigha160b242018-10-18 10:33:23 +0100597 CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
narpra0155a97bc2018-10-02 14:35:53 +0100598}
599
Sadik Armagan1625efc2021-06-10 18:24:34 +0100600TEST_CASE("CreateRefNormalizationFloat32NchwWorkload")
narpra0155a97bc2018-10-02 14:35:53 +0100601{
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100602 RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
Matteo Martincigha160b242018-10-18 10:33:23 +0100603}
604
Sadik Armagan1625efc2021-06-10 18:24:34 +0100605TEST_CASE("CreateRefNormalizationFloat32NhwcWorkload")
Matteo Martincigha160b242018-10-18 10:33:23 +0100606{
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100607 RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
608}
609
Sadik Armagan1625efc2021-06-10 18:24:34 +0100610TEST_CASE("CreateRefNormalizationUint8NchwWorkload")
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100611{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000612 RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QAsymmU8>(DataLayout::NCHW);
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100613}
614
Sadik Armagan1625efc2021-06-10 18:24:34 +0100615TEST_CASE("CreateRefNormalizationUint8NhwcWorkload")
Matteo Martincigh2fc70c52019-06-05 14:12:48 +0100616{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000617 RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QAsymmU8>(DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000618}
619
Sadik Armagan1625efc2021-06-10 18:24:34 +0100620TEST_CASE("CreateRefNormalizationInt16NchwWorkload")
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100621{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000622 RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QSymmS16>(DataLayout::NCHW);
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100623}
624
Sadik Armagan1625efc2021-06-10 18:24:34 +0100625TEST_CASE("CreateRefNormalizationInt16NhwcWorkload")
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100626{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000627 RefCreateNormalizationWorkloadTest<RefNormalizationWorkload, armnn::DataType::QSymmS16>(DataLayout::NHWC);
Matteo Martincigh6aeb7712019-06-05 17:23:29 +0100628}
629
telsoa01c577f2c2018-08-31 09:22:23 +0100630template <typename Pooling2dWorkloadType, armnn::DataType DataType>
James Conroy69482272018-10-19 10:41:35 +0100631static void RefCreatePooling2dWorkloadTest(DataLayout dataLayout)
telsoa014fcda012018-03-09 14:13:49 +0000632{
633 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100634 RefWorkloadFactory factory = GetFactory();
James Conroy69482272018-10-19 10:41:35 +0100635 auto workload = CreatePooling2dWorkloadTest<Pooling2dWorkloadType, DataType>(factory, graph, dataLayout);
636
637 TensorShape inputShape;
638 TensorShape outputShape;
639
640 switch (dataLayout)
641 {
642 case DataLayout::NHWC:
643 inputShape = { 3, 5, 5, 2 };
644 outputShape = { 3, 2, 4, 2 };
645 break;
646 case DataLayout::NCHW:
647 default:
648 inputShape = { 3, 2, 5, 5 };
649 outputShape = { 3, 2, 2, 4 };
650 }
telsoa014fcda012018-03-09 14:13:49 +0000651
telsoa01c577f2c2018-08-31 09:22:23 +0100652 // Checks that outputs and inputs are as we expect them (see definition of CreatePooling2dWorkloadTest).
James Conroy69482272018-10-19 10:41:35 +0100653 CheckInputOutput(std::move(workload),
654 TensorInfo(inputShape, DataType),
655 TensorInfo(outputShape, DataType));
telsoa014fcda012018-03-09 14:13:49 +0000656}
657
Sadik Armagan1625efc2021-06-10 18:24:34 +0100658TEST_CASE("CreatePooling2dFloat32Workload")
telsoa014fcda012018-03-09 14:13:49 +0000659{
Teresa Charlina3b20472019-06-06 11:12:32 +0100660 RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
James Conroy69482272018-10-19 10:41:35 +0100661}
662
Sadik Armagan1625efc2021-06-10 18:24:34 +0100663TEST_CASE("CreatePooling2dFloat32NhwcWorkload")
James Conroy69482272018-10-19 10:41:35 +0100664{
Teresa Charlina3b20472019-06-06 11:12:32 +0100665 RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000666}
667
Sadik Armagan1625efc2021-06-10 18:24:34 +0100668TEST_CASE("CreatePooling2dUint8Workload")
telsoa014fcda012018-03-09 14:13:49 +0000669{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000670 RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::QAsymmU8>(DataLayout::NCHW);
James Conroy69482272018-10-19 10:41:35 +0100671}
672
Sadik Armagan1625efc2021-06-10 18:24:34 +0100673TEST_CASE("CreatePooling2dUint8NhwcWorkload")
James Conroy69482272018-10-19 10:41:35 +0100674{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000675 RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::QAsymmU8>(DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000676}
677
Sadik Armagan1625efc2021-06-10 18:24:34 +0100678TEST_CASE("CreatePooling2dInt16Workload")
Teresa Charlin0434df62019-06-06 13:40:35 +0100679{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000680 RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::QSymmS16>(DataLayout::NCHW);
Teresa Charlin0434df62019-06-06 13:40:35 +0100681}
682
Sadik Armagan1625efc2021-06-10 18:24:34 +0100683TEST_CASE("CreatePooling2dInt16NhwcWorkload")
Teresa Charlin0434df62019-06-06 13:40:35 +0100684{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000685 RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::QSymmS16>(DataLayout::NHWC);
Teresa Charlin0434df62019-06-06 13:40:35 +0100686}
687
telsoa01c577f2c2018-08-31 09:22:23 +0100688template <typename SoftmaxWorkloadType, armnn::DataType DataType>
telsoa014fcda012018-03-09 14:13:49 +0000689static void RefCreateSoftmaxWorkloadTest()
690{
691 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100692 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +0100693 auto workload = CreateSoftmaxWorkloadTest<SoftmaxWorkloadType, DataType>(factory, graph);
telsoa014fcda012018-03-09 14:13:49 +0000694
telsoa01c577f2c2018-08-31 09:22:23 +0100695 // Checks that outputs and inputs are as we expect them (see definition of CreateSoftmaxWorkloadTest).
Sadik Armaganbe88a572020-04-30 11:39:37 +0100696
697 armnn::TensorInfo tensorInfo({4, 1}, DataType);
698 if (DataType == armnn::DataType::QAsymmU8)
699 {
700 tensorInfo.SetQuantizationOffset(0);
701 tensorInfo.SetQuantizationScale(1.f / 256);
702 }
703 else if (DataType == armnn::DataType::QAsymmS8)
704 {
705 tensorInfo.SetQuantizationOffset(-128);
706 tensorInfo.SetQuantizationScale(1.f / 256);
707 }
telsoa014fcda012018-03-09 14:13:49 +0000708 CheckInputOutput(
709 std::move(workload),
Sadik Armaganbe88a572020-04-30 11:39:37 +0100710 tensorInfo,
711 tensorInfo);
telsoa014fcda012018-03-09 14:13:49 +0000712}
713
Sadik Armagan1625efc2021-06-10 18:24:34 +0100714TEST_CASE("CreateSoftmaxFloat32Workload")
telsoa014fcda012018-03-09 14:13:49 +0000715{
nikraj01a121de32019-05-29 10:51:05 +0100716 RefCreateSoftmaxWorkloadTest<RefSoftmaxWorkload, armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000717}
718
Sadik Armagan1625efc2021-06-10 18:24:34 +0100719TEST_CASE("CreateSoftmaxFloat16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100720{
721 RefCreateSoftmaxWorkloadTest<RefSoftmaxWorkload, armnn::DataType::Float16>();
722}
723
Sadik Armagan1625efc2021-06-10 18:24:34 +0100724TEST_CASE("CreateSoftmaxQuantisedAsymm8Workload")
telsoa014fcda012018-03-09 14:13:49 +0000725{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000726 RefCreateSoftmaxWorkloadTest<RefSoftmaxWorkload, armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000727}
728
Sadik Armagan1625efc2021-06-10 18:24:34 +0100729TEST_CASE("CreateSoftmaxQuantisedSymm16Workload")
nikraj01248683f2019-05-29 16:46:50 +0100730{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000731 RefCreateSoftmaxWorkloadTest<RefSoftmaxWorkload, armnn::DataType::QSymmS16>();
nikraj01248683f2019-05-29 16:46:50 +0100732}
733
telsoa01c577f2c2018-08-31 09:22:23 +0100734template <typename SplitterWorkloadType, armnn::DataType DataType>
telsoa014fcda012018-03-09 14:13:49 +0000735static void RefCreateSplitterWorkloadTest()
736{
737 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100738 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +0100739 auto workload = CreateSplitterWorkloadTest<SplitterWorkloadType, DataType>(factory, graph);
telsoa014fcda012018-03-09 14:13:49 +0000740
telsoa01c577f2c2018-08-31 09:22:23 +0100741 // Checks that outputs are as we expect them (see definition of CreateSplitterWorkloadTest).
telsoa014fcda012018-03-09 14:13:49 +0000742 SplitterQueueDescriptor queueDescriptor = workload->GetData();
Jan Eilersbb446e52020-04-02 13:56:54 +0100743 auto inputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100744 CHECK((inputHandle->GetTensorInfo() == TensorInfo({ 5, 7, 7 }, DataType)));
surmeh013537c2c2018-05-18 16:31:43 +0100745
Jan Eilersbb446e52020-04-02 13:56:54 +0100746 auto outputHandle0 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100747 CHECK((outputHandle0->GetTensorInfo() == TensorInfo({ 1, 7, 7 }, DataType)));
surmeh013537c2c2018-05-18 16:31:43 +0100748
Jan Eilersbb446e52020-04-02 13:56:54 +0100749 auto outputHandle1 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[1]);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100750 CHECK((outputHandle1->GetTensorInfo() == TensorInfo({ 2, 7, 7 }, DataType)));
surmeh013537c2c2018-05-18 16:31:43 +0100751
Jan Eilersbb446e52020-04-02 13:56:54 +0100752 auto outputHandle2 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[2]);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100753 CHECK((outputHandle2->GetTensorInfo() == TensorInfo({ 2, 7, 7 }, DataType)));
telsoa014fcda012018-03-09 14:13:49 +0000754}
755
Sadik Armagan1625efc2021-06-10 18:24:34 +0100756TEST_CASE("CreateSplitterFloat32Workload")
telsoa014fcda012018-03-09 14:13:49 +0000757{
Ruomei Yan25339c32019-05-28 16:48:20 +0100758 RefCreateSplitterWorkloadTest<RefSplitterWorkload, armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000759}
760
Sadik Armagan1625efc2021-06-10 18:24:34 +0100761TEST_CASE("CreateSplitterFloat16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100762{
763 RefCreateSplitterWorkloadTest<RefSplitterWorkload, armnn::DataType::Float16>();
764}
765
Sadik Armagan1625efc2021-06-10 18:24:34 +0100766TEST_CASE("CreateSplitterUint8Workload")
telsoa014fcda012018-03-09 14:13:49 +0000767{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000768 RefCreateSplitterWorkloadTest<RefSplitterWorkload, armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000769}
770
Jim Flynne242f2d2019-05-22 14:24:13 +0100771template <typename SplitterWorkloadType, typename ConcatWorkloadType, armnn::DataType DataType>
772static void RefCreateSplitterConcatWorkloadTest()
telsoa014fcda012018-03-09 14:13:49 +0000773{
telsoa01c577f2c2018-08-31 09:22:23 +0100774 // Tests that it is possible to decide which output of the splitter layer
Jim Flynne242f2d2019-05-22 14:24:13 +0100775 // should be lined to which input of the concat layer.
telsoa01c577f2c2018-08-31 09:22:23 +0100776 // We tested that is is possible to specify 0th output
Jim Flynne242f2d2019-05-22 14:24:13 +0100777 // of the splitter to be the 1st input to the concat and the 1st output of the splitter to be 0th input
778 // of the concat.
telsoa014fcda012018-03-09 14:13:49 +0000779
780 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100781 RefWorkloadFactory factory = GetFactory();
Jim Flynne242f2d2019-05-22 14:24:13 +0100782 auto workloads = CreateSplitterConcatWorkloadTest<SplitterWorkloadType, ConcatWorkloadType, DataType>
783 (factory, graph);
telsoa014fcda012018-03-09 14:13:49 +0000784
785 auto wlSplitter = std::move(workloads.first);
Jim Flynne242f2d2019-05-22 14:24:13 +0100786 auto wlConcat = std::move(workloads.second);
telsoa014fcda012018-03-09 14:13:49 +0000787
telsoa01c577f2c2018-08-31 09:22:23 +0100788 //Checks that the index of inputs/outputs matches what we declared on InputDescriptor construction.
Matthew Bentham4cefc412019-06-18 16:14:34 +0100789 armnn::RefTensorHandle* sOut0 = dynamic_cast<armnn::RefTensorHandle*>(wlSplitter->GetData().m_Outputs[0]);
790 armnn::RefTensorHandle* sOut1 = dynamic_cast<armnn::RefTensorHandle*>(wlSplitter->GetData().m_Outputs[1]);
791 armnn::RefTensorHandle* mIn0 = dynamic_cast<armnn::RefTensorHandle*>(wlConcat->GetData().m_Inputs[0]);
792 armnn::RefTensorHandle* mIn1 = dynamic_cast<armnn::RefTensorHandle*>(wlConcat->GetData().m_Inputs[1]);
telsoa014fcda012018-03-09 14:13:49 +0000793
Sadik Armagan1625efc2021-06-10 18:24:34 +0100794 CHECK(sOut0);
795 CHECK(sOut1);
796 CHECK(mIn0);
797 CHECK(mIn1);
telsoa014fcda012018-03-09 14:13:49 +0000798
799 bool validDataPointers = (sOut0 == mIn1) && (sOut1 == mIn0);
800
Sadik Armagan1625efc2021-06-10 18:24:34 +0100801 CHECK(validDataPointers);
telsoa014fcda012018-03-09 14:13:49 +0000802}
803
Sadik Armagan1625efc2021-06-10 18:24:34 +0100804TEST_CASE("CreateSplitterConcatFloat32")
telsoa014fcda012018-03-09 14:13:49 +0000805{
Ruomei Yan25339c32019-05-28 16:48:20 +0100806 RefCreateSplitterConcatWorkloadTest<RefSplitterWorkload, RefConcatWorkload, DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000807}
808
Sadik Armagan1625efc2021-06-10 18:24:34 +0100809TEST_CASE("CreateSplitterConcatFloat16")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100810{
811 RefCreateSplitterConcatWorkloadTest<RefSplitterWorkload, RefConcatWorkload, DataType::Float16>();
812}
813
Sadik Armagan1625efc2021-06-10 18:24:34 +0100814TEST_CASE("CreateSplitterConcatUint8")
telsoa014fcda012018-03-09 14:13:49 +0000815{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000816 RefCreateSplitterConcatWorkloadTest<RefSplitterWorkload, RefConcatWorkload, DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000817}
818
telsoa01c577f2c2018-08-31 09:22:23 +0100819template <typename SplitterWorkloadType, typename ActivationWorkloadType, armnn::DataType DataType>
telsoa014fcda012018-03-09 14:13:49 +0000820static void RefCreateSingleOutputMultipleInputsTest()
821{
telsoa01c577f2c2018-08-31 09:22:23 +0100822 // Tests that it is possible to assign multiple (two) different layers to each of the outputs of a splitter layer.
823 // We created a splitter with two outputs. That each of those outputs is used by two different activation layers.
telsoa014fcda012018-03-09 14:13:49 +0000824
825 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100826 RefWorkloadFactory factory = GetFactory();
telsoa014fcda012018-03-09 14:13:49 +0000827 std::unique_ptr<SplitterWorkloadType> wlSplitter;
828 std::unique_ptr<ActivationWorkloadType> wlActiv0_0;
829 std::unique_ptr<ActivationWorkloadType> wlActiv0_1;
830 std::unique_ptr<ActivationWorkloadType> wlActiv1_0;
831 std::unique_ptr<ActivationWorkloadType> wlActiv1_1;
832
833 CreateSplitterMultipleInputsOneOutputWorkloadTest<SplitterWorkloadType,
telsoa01c577f2c2018-08-31 09:22:23 +0100834 ActivationWorkloadType, DataType>(factory, graph, wlSplitter, wlActiv0_0, wlActiv0_1, wlActiv1_0, wlActiv1_1);
telsoa014fcda012018-03-09 14:13:49 +0000835
Matthew Bentham4cefc412019-06-18 16:14:34 +0100836 armnn::RefTensorHandle* sOut0 = dynamic_cast<armnn::RefTensorHandle*>(wlSplitter->GetData().m_Outputs[0]);
837 armnn::RefTensorHandle* sOut1 = dynamic_cast<armnn::RefTensorHandle*>(wlSplitter->GetData().m_Outputs[1]);
838 armnn::RefTensorHandle* activ0_0Im = dynamic_cast<armnn::RefTensorHandle*>(wlActiv0_0->GetData().m_Inputs[0]);
839 armnn::RefTensorHandle* activ0_1Im = dynamic_cast<armnn::RefTensorHandle*>(wlActiv0_1->GetData().m_Inputs[0]);
840 armnn::RefTensorHandle* activ1_0Im = dynamic_cast<armnn::RefTensorHandle*>(wlActiv1_0->GetData().m_Inputs[0]);
841 armnn::RefTensorHandle* activ1_1Im = dynamic_cast<armnn::RefTensorHandle*>(wlActiv1_1->GetData().m_Inputs[0]);
telsoa014fcda012018-03-09 14:13:49 +0000842
843
Sadik Armagan1625efc2021-06-10 18:24:34 +0100844 CHECK(sOut0);
845 CHECK(sOut1);
846 CHECK(activ0_0Im);
847 CHECK(activ0_1Im);
848 CHECK(activ1_0Im);
849 CHECK(activ1_1Im);
telsoa014fcda012018-03-09 14:13:49 +0000850
851 bool validDataPointers = (sOut0 == activ0_0Im) && (sOut0 == activ0_1Im) &&
852 (sOut1 == activ1_0Im) && (sOut1 == activ1_1Im);
853
Sadik Armagan1625efc2021-06-10 18:24:34 +0100854 CHECK(validDataPointers);
telsoa014fcda012018-03-09 14:13:49 +0000855}
856
Sadik Armagan1625efc2021-06-10 18:24:34 +0100857TEST_CASE("CreateSingleOutputMultipleInputsFloat32")
telsoa014fcda012018-03-09 14:13:49 +0000858{
Ruomei Yan25339c32019-05-28 16:48:20 +0100859 RefCreateSingleOutputMultipleInputsTest<RefSplitterWorkload, RefActivationWorkload,
telsoa01c577f2c2018-08-31 09:22:23 +0100860 armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +0000861}
862
Sadik Armagan1625efc2021-06-10 18:24:34 +0100863TEST_CASE("CreateSingleOutputMultipleInputsUint8")
telsoa014fcda012018-03-09 14:13:49 +0000864{
Ruomei Yan25339c32019-05-28 16:48:20 +0100865 RefCreateSingleOutputMultipleInputsTest<RefSplitterWorkload, RefActivationWorkload,
Derek Lambertif90c56d2020-01-10 17:14:08 +0000866 armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +0000867}
868
telsoa01c577f2c2018-08-31 09:22:23 +0100869template <typename ResizeBilinearWorkloadType, armnn::DataType DataType>
James Conroy59540822018-10-11 12:39:05 +0100870static void RefCreateResizeBilinearTest(DataLayout dataLayout)
telsoa014fcda012018-03-09 14:13:49 +0000871{
872 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100873 RefWorkloadFactory factory = GetFactory();
James Conroy59540822018-10-11 12:39:05 +0100874 auto workload = CreateResizeBilinearWorkloadTest<ResizeBilinearWorkloadType, DataType>(factory, graph, dataLayout);
875
876 TensorShape inputShape;
877 TensorShape outputShape;
878
879 switch (dataLayout)
880 {
881 case DataLayout::NHWC:
882 inputShape = { 2, 4, 4, 3 };
883 outputShape = { 2, 2, 2, 3 };
884 break;
James Conroy69482272018-10-19 10:41:35 +0100885 case DataLayout::NCHW:
886 default:
James Conroy59540822018-10-11 12:39:05 +0100887 inputShape = { 2, 3, 4, 4 };
888 outputShape = { 2, 3, 2, 2 };
889 }
telsoa014fcda012018-03-09 14:13:49 +0000890
telsoa01c577f2c2018-08-31 09:22:23 +0100891 // Checks that outputs and inputs are as we expect them (see definition of CreateResizeBilinearWorkloadTest).
James Conroy69482272018-10-19 10:41:35 +0100892 CheckInputOutput(std::move(workload),
893 TensorInfo(inputShape, DataType),
894 TensorInfo(outputShape, DataType));
telsoa014fcda012018-03-09 14:13:49 +0000895}
896
Sadik Armagan1625efc2021-06-10 18:24:34 +0100897TEST_CASE("CreateResizeBilinearFloat32")
telsoa014fcda012018-03-09 14:13:49 +0000898{
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100899 RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
telsoa014fcda012018-03-09 14:13:49 +0000900}
901
Sadik Armagan1625efc2021-06-10 18:24:34 +0100902TEST_CASE("CreateResizeBilinearFloat16")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100903{
904 RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::Float16>(DataLayout::NCHW);
905}
906
Sadik Armagan1625efc2021-06-10 18:24:34 +0100907TEST_CASE("CreateResizeBilinearUint8")
telsoa014fcda012018-03-09 14:13:49 +0000908{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000909 RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::QAsymmU8>(DataLayout::NCHW);
James Conroy59540822018-10-11 12:39:05 +0100910}
911
Sadik Armagan1625efc2021-06-10 18:24:34 +0100912TEST_CASE("CreateResizeBilinearQuantisedAsymm16")
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +0100913{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000914 RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::QSymmS16>(DataLayout::NCHW);
Ellen Norris-Thompson3cb85f32019-06-17 11:32:49 +0100915}
916
Sadik Armagan1625efc2021-06-10 18:24:34 +0100917TEST_CASE("CreateResizeBilinearFloat32Nhwc")
James Conroy59540822018-10-11 12:39:05 +0100918{
Aron Virginas-Tar169d2f12019-07-01 19:01:44 +0100919 RefCreateResizeBilinearTest<RefResizeWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +0000920}
921
Francis Murtagh57f13d52019-06-24 14:24:36 +0100922template <typename BatchToSpaceNdWorkloadType, armnn::DataType DataType>
923static void RefCreateBatchToSpaceNdTest()
924{
925 Graph graph;
926 RefWorkloadFactory factory;
927
928 auto workload = CreateBatchToSpaceNdWorkloadTest<BatchToSpaceNdWorkloadType, DataType>(factory, graph);
929
930 CheckInputOutput(std::move(workload),
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100931 TensorInfo({ 1, 1, 1, 1 }, DataType),
932 TensorInfo({ 1, 1, 1, 1 }, DataType));
Francis Murtagh57f13d52019-06-24 14:24:36 +0100933}
934
Sadik Armagan1625efc2021-06-10 18:24:34 +0100935TEST_CASE("CreateBatchToSpaceNdFloat32")
Francis Murtagh57f13d52019-06-24 14:24:36 +0100936{
937 RefCreateBatchToSpaceNdTest<RefBatchToSpaceNdWorkload, armnn::DataType::Float32>();
938}
939
Sadik Armagan1625efc2021-06-10 18:24:34 +0100940TEST_CASE("CreateBatchToSpaceNdFloat16")
Matthew Jackson9bff1442019-09-12 09:08:23 +0100941{
942 RefCreateBatchToSpaceNdTest<RefBatchToSpaceNdWorkload, armnn::DataType::Float16>();
943}
944
Sadik Armagan1625efc2021-06-10 18:24:34 +0100945TEST_CASE("CreateBatchToSpaceNdUint8")
Francis Murtagh57f13d52019-06-24 14:24:36 +0100946{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000947 RefCreateBatchToSpaceNdTest<RefBatchToSpaceNdWorkload, armnn::DataType::QAsymmU8>();
Francis Murtagh57f13d52019-06-24 14:24:36 +0100948}
949
Sadik Armagan1625efc2021-06-10 18:24:34 +0100950TEST_CASE("CreateBatchToSpaceNdQSymm16")
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100951{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000952 RefCreateBatchToSpaceNdTest<RefBatchToSpaceNdWorkload, armnn::DataType::QSymmS16>();
Francis Murtaghd0dfe172019-06-25 10:57:10 +0100953}
954
Matteo Martincighb63973e2018-10-16 16:23:33 +0100955template <typename L2NormalizationWorkloadType, armnn::DataType DataType>
956static void RefCreateL2NormalizationTest(DataLayout dataLayout)
telsoa014fcda012018-03-09 14:13:49 +0000957{
958 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100959 RefWorkloadFactory factory = GetFactory();
Matteo Martincighb63973e2018-10-16 16:23:33 +0100960 auto workload =
961 CreateL2NormalizationWorkloadTest<L2NormalizationWorkloadType, DataType>(factory, graph, dataLayout);
962
963 TensorShape inputShape;
964 TensorShape outputShape;
965
966 switch (dataLayout)
967 {
968 case DataLayout::NHWC:
969 inputShape = { 5, 50, 67, 20 };
970 outputShape = { 5, 50, 67, 20 };
971 break;
972 case DataLayout::NCHW:
973 default:
974 inputShape = { 5, 20, 50, 67 };
975 outputShape = { 5, 20, 50, 67 };
976 break;
977 }
telsoa014fcda012018-03-09 14:13:49 +0000978
telsoa01c577f2c2018-08-31 09:22:23 +0100979 // Checks that outputs and inputs are as we expect them (see definition of CreateL2NormalizationWorkloadTest).
Matteo Martincighb63973e2018-10-16 16:23:33 +0100980 CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
981}
982
Sadik Armagan1625efc2021-06-10 18:24:34 +0100983TEST_CASE("CreateL2NormalizationFloat32")
Matteo Martincighb63973e2018-10-16 16:23:33 +0100984{
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100985 RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
Matteo Martincighb63973e2018-10-16 16:23:33 +0100986}
987
Sadik Armagan1625efc2021-06-10 18:24:34 +0100988TEST_CASE("CreateL2NormalizationFloat32Nhwc")
Matteo Martincighb63973e2018-10-16 16:23:33 +0100989{
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100990 RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
991}
992
Sadik Armagan1625efc2021-06-10 18:24:34 +0100993TEST_CASE("CreateL2NormalizationInt16")
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100994{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000995 RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::QSymmS16>(DataLayout::NCHW);
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100996}
997
Sadik Armagan1625efc2021-06-10 18:24:34 +0100998TEST_CASE("CreateL2NormalizationInt16Nhwc")
Ferran Balaguerd73d14f2019-06-10 10:29:54 +0100999{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001000 RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::QSymmS16>(DataLayout::NHWC);
telsoa014fcda012018-03-09 14:13:49 +00001001}
1002
Sadik Armagan1625efc2021-06-10 18:24:34 +01001003TEST_CASE("CreateL2NormalizationUint8")
Ferran Balaguerc6138d82019-06-13 17:23:50 +01001004{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001005 RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::QAsymmU8>(DataLayout::NCHW);
Ferran Balaguerc6138d82019-06-13 17:23:50 +01001006}
1007
Sadik Armagan1625efc2021-06-10 18:24:34 +01001008TEST_CASE("CreateL2NormalizationUint8Nhwc")
Ferran Balaguerc6138d82019-06-13 17:23:50 +01001009{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001010 RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::QAsymmU8>(DataLayout::NHWC);
Ferran Balaguerc6138d82019-06-13 17:23:50 +01001011}
1012
telsoa01c577f2c2018-08-31 09:22:23 +01001013template <typename ReshapeWorkloadType, armnn::DataType DataType>
telsoa014fcda012018-03-09 14:13:49 +00001014static void RefCreateReshapeWorkloadTest()
1015{
1016 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +01001017 RefWorkloadFactory factory = GetFactory();
telsoa01c577f2c2018-08-31 09:22:23 +01001018 auto workload = CreateReshapeWorkloadTest<ReshapeWorkloadType, DataType>(factory, graph);
telsoa014fcda012018-03-09 14:13:49 +00001019
telsoa01c577f2c2018-08-31 09:22:23 +01001020 // Checks that outputs and inputs are as we expect them (see definition of CreateReshapeWorkloadTest).
telsoa014fcda012018-03-09 14:13:49 +00001021 CheckInputOutput(
1022 std::move(workload),
telsoa01c577f2c2018-08-31 09:22:23 +01001023 TensorInfo({ 4, 1 }, DataType),
1024 TensorInfo({ 1, 4 }, DataType));
telsoa014fcda012018-03-09 14:13:49 +00001025}
1026
Sadik Armagan1625efc2021-06-10 18:24:34 +01001027TEST_CASE("CreateReshapeWorkloadFloat32")
telsoa014fcda012018-03-09 14:13:49 +00001028{
Nina Drozd2f2778f2019-05-27 10:37:05 +01001029 RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::Float32>();
telsoa014fcda012018-03-09 14:13:49 +00001030}
1031
Sadik Armagan1625efc2021-06-10 18:24:34 +01001032TEST_CASE("CreateReshapeWorkloadQuantisedAsymm8")
telsoa014fcda012018-03-09 14:13:49 +00001033{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001034 RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::QAsymmU8>();
telsoa014fcda012018-03-09 14:13:49 +00001035}
1036
Sadik Armagan1625efc2021-06-10 18:24:34 +01001037TEST_CASE("CreateReshapeWorkloadQuantisedSymm16")
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001038{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001039 RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::QSymmS16>();
Nina Drozd8ed4b8c2019-05-29 10:41:04 +01001040}
1041
Jim Flynne242f2d2019-05-22 14:24:13 +01001042template <typename ConcatWorkloadType, armnn::DataType DataType>
1043static void RefCreateConcatWorkloadTest(const armnn::TensorShape& outputShape,
narpra015cdda352018-11-19 15:30:27 +00001044 unsigned int concatAxis)
1045{
1046 Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +01001047 RefWorkloadFactory factory = GetFactory();
Jim Flynne242f2d2019-05-22 14:24:13 +01001048 auto workload = CreateConcatWorkloadTest<ConcatWorkloadType, DataType>(factory, graph, outputShape, concatAxis);
narpra015cdda352018-11-19 15:30:27 +00001049
1050 CheckInputsOutput(std::move(workload),
1051 TensorInfo({ 2, 3, 2, 5 }, DataType),
1052 TensorInfo({ 2, 3, 2, 5 }, DataType),
1053 TensorInfo(outputShape, DataType));
1054}
1055
Sadik Armagan1625efc2021-06-10 18:24:34 +01001056TEST_CASE("CreateConcatDim0Float32Workload")
narpra015cdda352018-11-19 15:30:27 +00001057{
Jim Flynne242f2d2019-05-22 14:24:13 +01001058 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 4, 3, 2, 5 }, 0);
narpra015cdda352018-11-19 15:30:27 +00001059}
1060
Sadik Armagan1625efc2021-06-10 18:24:34 +01001061TEST_CASE("CreateConcatDim0Float16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +01001062{
1063 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float16>({ 4, 3, 2, 5 }, 0);
1064}
1065
Sadik Armagan1625efc2021-06-10 18:24:34 +01001066TEST_CASE("CreateConcatDim0Uint8Workload")
narpra015cdda352018-11-19 15:30:27 +00001067{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001068 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QAsymmU8>({ 4, 3, 2, 5 }, 0);
Jim Flynncbb66aa2019-05-15 13:03:54 +01001069}
1070
Sadik Armagan1625efc2021-06-10 18:24:34 +01001071TEST_CASE("CreateConcatDim0Uint16Workload")
Jim Flynncbb66aa2019-05-15 13:03:54 +01001072{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001073 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QSymmS16>({ 4, 3, 2, 5 }, 0);
narpra015cdda352018-11-19 15:30:27 +00001074}
1075
Sadik Armagan1625efc2021-06-10 18:24:34 +01001076TEST_CASE("CreateConcatDim1Float32Workload")
narpra015cdda352018-11-19 15:30:27 +00001077{
Jim Flynne242f2d2019-05-22 14:24:13 +01001078 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 2, 6, 2, 5 }, 1);
narpra015cdda352018-11-19 15:30:27 +00001079}
1080
Sadik Armagan1625efc2021-06-10 18:24:34 +01001081TEST_CASE("CreateConcatDim1Uint8Workload")
narpra015cdda352018-11-19 15:30:27 +00001082{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001083 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QAsymmU8>({ 2, 6, 2, 5 }, 1);
narpra015cdda352018-11-19 15:30:27 +00001084}
1085
Sadik Armagan1625efc2021-06-10 18:24:34 +01001086TEST_CASE("CreateConcatDim2Float32Workload")
narpra015cdda352018-11-19 15:30:27 +00001087{
Jim Flynne242f2d2019-05-22 14:24:13 +01001088 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 2, 3, 4, 5 }, 2);
narpra015cdda352018-11-19 15:30:27 +00001089}
1090
Sadik Armagan1625efc2021-06-10 18:24:34 +01001091TEST_CASE("CreateConcatDim2Uint8Workload")
narpra015cdda352018-11-19 15:30:27 +00001092{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001093 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QAsymmU8>({ 2, 3, 4, 5 }, 2);
narpra015cdda352018-11-19 15:30:27 +00001094}
1095
Sadik Armagan1625efc2021-06-10 18:24:34 +01001096TEST_CASE("CreateConcatDim3Float32Workload")
narpra015cdda352018-11-19 15:30:27 +00001097{
Jim Flynne242f2d2019-05-22 14:24:13 +01001098 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::Float32>({ 2, 3, 2, 10 }, 3);
narpra015cdda352018-11-19 15:30:27 +00001099}
1100
Sadik Armagan1625efc2021-06-10 18:24:34 +01001101TEST_CASE("CreateConcatDim3Uint8Workload")
narpra015cdda352018-11-19 15:30:27 +00001102{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001103 RefCreateConcatWorkloadTest<RefConcatWorkload, armnn::DataType::QAsymmU8>({ 2, 3, 2, 10 }, 3);
narpra015cdda352018-11-19 15:30:27 +00001104}
1105
Nina Drozd58ef2c62019-05-16 12:09:18 +01001106template <typename ConstantWorkloadType, armnn::DataType DataType>
1107static void RefCreateConstantWorkloadTest(const armnn::TensorShape& outputShape)
1108{
1109 armnn::Graph graph;
Matthew Bentham7c1603a2019-06-21 17:22:23 +01001110 RefWorkloadFactory factory = GetFactory();
Nina Drozd58ef2c62019-05-16 12:09:18 +01001111 auto workload = CreateConstantWorkloadTest<ConstantWorkloadType, DataType>(factory, graph, outputShape);
1112
1113 // Check output is as expected
1114 auto queueDescriptor = workload->GetData();
Jan Eilersbb446e52020-04-02 13:56:54 +01001115 auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +01001116 CHECK((outputHandle->GetTensorInfo() == TensorInfo(outputShape, DataType)));
Nina Drozd58ef2c62019-05-16 12:09:18 +01001117}
1118
Sadik Armagan1625efc2021-06-10 18:24:34 +01001119TEST_CASE("CreateConstantUint8Workload")
Nina Drozd58ef2c62019-05-16 12:09:18 +01001120{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001121 RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::QAsymmU8>({ 2, 3, 2, 10 });
Nina Drozd58ef2c62019-05-16 12:09:18 +01001122}
1123
Sadik Armagan1625efc2021-06-10 18:24:34 +01001124TEST_CASE("CreateConstantInt16Workload")
Nina Drozd58ef2c62019-05-16 12:09:18 +01001125{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001126 RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::QSymmS16>({ 2, 3, 2, 10 });
Nina Drozd58ef2c62019-05-16 12:09:18 +01001127}
1128
Sadik Armagan1625efc2021-06-10 18:24:34 +01001129TEST_CASE("CreateConstantFloat32Workload")
Nina Drozd58ef2c62019-05-16 12:09:18 +01001130{
1131 RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::Float32>({ 2, 3, 2, 10 });
1132}
1133
Sadik Armagan1625efc2021-06-10 18:24:34 +01001134TEST_CASE("CreateConstantSigned32Workload")
Nina Drozd58ef2c62019-05-16 12:09:18 +01001135{
1136 RefCreateConstantWorkloadTest<RefConstantWorkload, armnn::DataType::Signed32>({ 2, 3, 2, 10 });
1137}
1138
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001139static void RefCreatePreluWorkloadTest(const armnn::TensorShape& inputShape,
1140 const armnn::TensorShape& alphaShape,
1141 const armnn::TensorShape& outputShape,
1142 armnn::DataType dataType)
Matteo Martincighab9e5252019-06-13 17:27:46 +01001143{
1144 armnn::Graph graph;
1145 RefWorkloadFactory factory;
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001146 auto workload = CreatePreluWorkloadTest<RefPreluWorkload>(factory,
1147 graph,
1148 inputShape,
1149 alphaShape,
1150 outputShape,
1151 dataType);
Matteo Martincighab9e5252019-06-13 17:27:46 +01001152
1153 // Check output is as expected
1154 auto queueDescriptor = workload->GetData();
Jan Eilersbb446e52020-04-02 13:56:54 +01001155 auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +01001156 CHECK((outputHandle->GetTensorInfo() == TensorInfo(outputShape, dataType)));
Matteo Martincighab9e5252019-06-13 17:27:46 +01001157}
1158
Sadik Armagan1625efc2021-06-10 18:24:34 +01001159TEST_CASE("CreatePreluFloat32Workload")
Matteo Martincighab9e5252019-06-13 17:27:46 +01001160{
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001161 RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::Float32);
Matteo Martincighab9e5252019-06-13 17:27:46 +01001162}
1163
Sadik Armagan1625efc2021-06-10 18:24:34 +01001164TEST_CASE("CreatePreluFloat16Workload")
Matthew Jackson9bff1442019-09-12 09:08:23 +01001165{
1166 RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::Float16);
1167}
1168
Sadik Armagan1625efc2021-06-10 18:24:34 +01001169TEST_CASE("CreatePreluUint8Workload")
Matteo Martincighab9e5252019-06-13 17:27:46 +01001170{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001171 RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::QAsymmU8);
Matteo Martincighab9e5252019-06-13 17:27:46 +01001172}
1173
Sadik Armagan1625efc2021-06-10 18:24:34 +01001174TEST_CASE("CreatePreluInt16Workload")
Matteo Martincighab9e5252019-06-13 17:27:46 +01001175{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001176 RefCreatePreluWorkloadTest({ 1, 4, 1, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 }, armnn::DataType::QSymmS16);
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001177}
1178
Sadik Armagan1625efc2021-06-10 18:24:34 +01001179TEST_CASE("CreatePreluFloat32NoBroadcastWorkload")
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001180{
Sadik Armagan1625efc2021-06-10 18:24:34 +01001181 CHECK_THROWS_AS(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001182 armnn::DataType::Float32),
1183 armnn::InvalidArgumentException);
1184}
1185
Sadik Armagan1625efc2021-06-10 18:24:34 +01001186TEST_CASE("CreatePreluFloat16NoBroadcastWorkload")
Matthew Jackson9bff1442019-09-12 09:08:23 +01001187{
Sadik Armagan1625efc2021-06-10 18:24:34 +01001188 CHECK_THROWS_AS(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
Matthew Jackson9bff1442019-09-12 09:08:23 +01001189 armnn::DataType::Float16),
1190 armnn::InvalidArgumentException);
1191}
1192
Sadik Armagan1625efc2021-06-10 18:24:34 +01001193TEST_CASE("CreatePreluUint8NoBroadcastWorkload")
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001194{
Sadik Armagan1625efc2021-06-10 18:24:34 +01001195 CHECK_THROWS_AS(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
Derek Lambertif90c56d2020-01-10 17:14:08 +00001196 armnn::DataType::QAsymmU8),
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001197 armnn::InvalidArgumentException);
1198}
1199
Sadik Armagan1625efc2021-06-10 18:24:34 +01001200TEST_CASE("CreatePreluInt16NoBroadcastWorkload")
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001201{
Sadik Armagan1625efc2021-06-10 18:24:34 +01001202 CHECK_THROWS_AS(RefCreatePreluWorkloadTest({ 1, 4, 7, 2 }, { 5, 4, 3, 1 }, { 5, 4, 3, 2 },
Derek Lambertif90c56d2020-01-10 17:14:08 +00001203 armnn::DataType::QSymmS16),
Matteo Martincighbf0e7222019-06-20 17:17:45 +01001204 armnn::InvalidArgumentException);
Matteo Martincighab9e5252019-06-13 17:27:46 +01001205}
1206
James Conroy60597842019-07-02 10:57:56 +01001207template <typename SpaceToDepthWorkloadType, armnn::DataType DataType>
1208static void RefCreateSpaceToDepthWorkloadTest()
1209{
1210 Graph graph;
1211 RefWorkloadFactory factory;
1212
1213 auto workload = CreateSpaceToDepthWorkloadTest<SpaceToDepthWorkloadType, DataType>(factory, graph);
1214
1215 CheckInputOutput(std::move(workload),
1216 TensorInfo({ 1, 2, 2, 1 }, DataType),
1217 TensorInfo({ 1, 1, 1, 4 }, DataType));
1218}
1219
Sadik Armagan1625efc2021-06-10 18:24:34 +01001220TEST_CASE("CreateSpaceToDepthWorkloadFloat32")
James Conroy60597842019-07-02 10:57:56 +01001221{
1222 RefCreateSpaceToDepthWorkloadTest<RefSpaceToDepthWorkload, armnn::DataType::Float32>();
1223}
1224
Sadik Armagan1625efc2021-06-10 18:24:34 +01001225TEST_CASE("CreateSpaceToDepthWorkloadFloat16")
Matthew Jackson9bff1442019-09-12 09:08:23 +01001226{
1227 RefCreateSpaceToDepthWorkloadTest<RefSpaceToDepthWorkload, armnn::DataType::Float16>();
1228}
1229
Sadik Armagan1625efc2021-06-10 18:24:34 +01001230TEST_CASE("CreateSpaceToDepthWorkloadQASymm8")
James Conroy60597842019-07-02 10:57:56 +01001231{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001232 RefCreateSpaceToDepthWorkloadTest<RefSpaceToDepthWorkload, armnn::DataType::QAsymmU8>();
James Conroy60597842019-07-02 10:57:56 +01001233}
1234
Sadik Armagan1625efc2021-06-10 18:24:34 +01001235TEST_CASE("CreateSpaceToDepthWorkloadQSymm16")
James Conroy60597842019-07-02 10:57:56 +01001236{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001237 RefCreateSpaceToDepthWorkloadTest<RefSpaceToDepthWorkload, armnn::DataType::QSymmS16>();
James Conroy60597842019-07-02 10:57:56 +01001238}
1239
Matthew Jacksond5166102019-07-31 14:06:28 +01001240template <armnn::DataType DataType>
Matthew Jackson81e601c2019-07-11 12:07:09 +01001241static void RefCreateStackWorkloadTest(const armnn::TensorShape& inputShape,
1242 const armnn::TensorShape& outputShape,
1243 unsigned int axis,
Matthew Jacksond5166102019-07-31 14:06:28 +01001244 unsigned int numInputs)
Matthew Jackson81e601c2019-07-11 12:07:09 +01001245{
1246 armnn::Graph graph;
1247 RefWorkloadFactory factory;
Matthew Jacksond5166102019-07-31 14:06:28 +01001248 auto workload = CreateStackWorkloadTest<RefStackWorkload, DataType>(factory,
1249 graph,
1250 inputShape,
1251 outputShape,
1252 axis,
1253 numInputs);
Matthew Jackson81e601c2019-07-11 12:07:09 +01001254
Matthew Jacksond5166102019-07-31 14:06:28 +01001255 // Check inputs and output are as expected
1256 StackQueueDescriptor queueDescriptor = workload->GetData();
1257 for (unsigned int i = 0; i < numInputs; ++i)
1258 {
Jan Eilersbb446e52020-04-02 13:56:54 +01001259 auto inputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[i]);
Sadik Armagan1625efc2021-06-10 18:24:34 +01001260 CHECK((inputHandle->GetTensorInfo() == TensorInfo(inputShape, DataType)));
Matthew Jacksond5166102019-07-31 14:06:28 +01001261 }
Jan Eilersbb446e52020-04-02 13:56:54 +01001262 auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +01001263 CHECK((outputHandle->GetTensorInfo() == TensorInfo(outputShape, DataType)));
Matthew Jackson81e601c2019-07-11 12:07:09 +01001264}
1265
Sadik Armagan1625efc2021-06-10 18:24:34 +01001266TEST_CASE("CreateStackFloat32Workload")
Matthew Jackson81e601c2019-07-11 12:07:09 +01001267{
Matthew Jacksond5166102019-07-31 14:06:28 +01001268 RefCreateStackWorkloadTest<armnn::DataType::Float32>({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2);
Matthew Jackson81e601c2019-07-11 12:07:09 +01001269}
1270
Sadik Armagan1625efc2021-06-10 18:24:34 +01001271TEST_CASE("CreateStackUint8Workload")
Matthew Jackson81e601c2019-07-11 12:07:09 +01001272{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001273 RefCreateStackWorkloadTest<armnn::DataType::QAsymmU8>({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2);
Matthew Jackson81e601c2019-07-11 12:07:09 +01001274}
1275
Sadik Armagan1625efc2021-06-10 18:24:34 +01001276TEST_CASE("CreateStackUint16Workload")
Matthew Jackson81e601c2019-07-11 12:07:09 +01001277{
Derek Lambertif90c56d2020-01-10 17:14:08 +00001278 RefCreateStackWorkloadTest<armnn::DataType::QSymmS16>({ 3, 4, 5 }, { 3, 4, 2, 5 }, 2, 2);
Matthew Jackson81e601c2019-07-11 12:07:09 +01001279}
1280
James Conroy4f1f8992020-04-29 20:01:10 +01001281template <typename QLstmWorkloadType>
1282static void RefCreateQLstmWorkloadTest()
1283{
1284 Graph graph;
1285 RefWorkloadFactory factory;
1286
1287 auto workload = CreateQLstmWorkloadTest<QLstmWorkloadType>(factory, graph);
1288
1289 armnn::TensorInfo inputInfo({2 , 4}, armnn::DataType::QAsymmS8, 0.0078125f, 0);
1290
1291 armnn::TensorInfo cellStateInfo({2 , 4}, armnn::DataType::QSymmS16, 3.05176e-05f, 0);
1292
1293 armnn::TensorInfo outputInfo({2 , 4}, armnn::DataType::QAsymmS8, 0.007f, 0);
1294
1295 QLstmQueueDescriptor queueDescriptor = workload->GetData();
Jan Eilersaaf9a8f2020-07-01 16:35:35 +01001296 auto inputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
1297 auto cellStateOutHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[1]);
1298 auto outputHandle = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[2]);
James Conroy4f1f8992020-04-29 20:01:10 +01001299
Sadik Armagan1625efc2021-06-10 18:24:34 +01001300 CHECK((inputHandle->GetTensorInfo() == inputInfo));
1301 CHECK((cellStateOutHandle->GetTensorInfo() == cellStateInfo));
1302 CHECK((outputHandle->GetTensorInfo() == outputInfo));
James Conroy4f1f8992020-04-29 20:01:10 +01001303}
1304
Sadik Armagan1625efc2021-06-10 18:24:34 +01001305TEST_CASE("CreateQLstmWorkload")
James Conroy4f1f8992020-04-29 20:01:10 +01001306{
1307 RefCreateQLstmWorkloadTest<RefQLstmWorkload>();
1308}
1309
Teresa Charlin788e2a62022-01-17 21:19:52 +00001310template <armnn::DataType DataType>
1311static void RefCreateActivationWorkloadReplaceFunctionsTest()
1312{
1313 Graph graph;
1314 RefWorkloadFactory factory = GetFactory();
1315 // input and output are created as armnn::TensorInfo tensorInfo({1, 1}, DataType)
1316 auto workloadPtr = CreateActivationWorkloadTest<RefActivationWorkload, DataType>(factory, graph);
1317
1318 // new input and output tensor handlers are created and then replace in the workload
1319 shared_ptr<RefMemoryManager> memoryManager = make_shared<RefMemoryManager>();
1320 const RefTensorHandleFactory tensorHandleFactory(memoryManager);
Rob Hughes5bcc0722022-01-21 10:56:14 +00001321 TensorInfo inputInfo({2 , 2}, armnn::DataType::Float16);
1322 TensorInfo outputInfo({2 , 2}, armnn::DataType::Float16);
Teresa Charlin788e2a62022-01-17 21:19:52 +00001323 unique_ptr<ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
1324 unique_ptr<ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo);
1325 unsigned int slot = 0;
1326 workloadPtr->ReplaceInputTensorHandle(inputHandle.get(), slot);
1327 workloadPtr->ReplaceOutputTensorHandle(outputHandle.get(), slot);
1328
1329 // Check if the tensor handlers inside the workload are the same as ones we replace with
1330 auto queueDescriptor = workloadPtr->GetData();
1331 auto inputHandleTest = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Inputs[0]);
1332 auto outputHandleTest = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor.m_Outputs[0]);
1333 CHECK((inputHandleTest->GetTensorInfo() == inputInfo));
1334 CHECK((outputHandleTest->GetTensorInfo() == outputInfo));
1335 CHECK(inputHandle.get() == inputHandleTest);
1336 CHECK(outputHandle.get() == outputHandleTest);
1337 inputHandle->Allocate();
1338 CHECK(inputHandle->Map() == inputHandleTest->Map());
1339 outputHandle->Allocate();
1340 CHECK(outputHandle->Map() == outputHandleTest->Map());
1341}
1342
1343TEST_CASE("ReplaceFunctionsfromFloat32toFloat16ActivationWorkload")
1344{
1345 RefCreateActivationWorkloadReplaceFunctionsTest<armnn::DataType::Float32>();
1346}
1347
1348TEST_CASE("ReplaceFunctionsfromUint8toFloat16ActivationWorkload")
1349{
1350 RefCreateActivationWorkloadReplaceFunctionsTest<armnn::DataType::QAsymmU8>();
1351}
1352
Sadik Armagan1625efc2021-06-10 18:24:34 +01001353}