blob: a7783aa650144dbcdbc6acb600fedef493c81612 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Sadik Armagana097d2a2021-11-24 15:47:28 +00007#include <CreateWorkload.hpp>
8#include <armnnTestUtils/PredicateResult.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +01009#include <armnn/utility/PolymorphicDowncast.hpp>
Colm Donelan0c479742021-12-10 12:43:54 +000010#include <armnn/backends/MemCopyWorkload.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000011#include <reference/RefWorkloadFactory.hpp>
Matthew Bentham4cefc412019-06-18 16:14:34 +010012#include <reference/RefTensorHandle.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Matteo Martincighd95e9062019-01-31 15:35:59 +000014#if defined(ARMCOMPUTECL_ENABLED)
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000015#include <cl/ClTensorHandle.hpp>
telsoa014fcda012018-03-09 14:13:49 +000016#endif
17
Matteo Martincighd95e9062019-01-31 15:35:59 +000018#if defined(ARMCOMPUTENEON_ENABLED)
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000019#include <neon/NeonTensorHandle.hpp>
telsoa014fcda012018-03-09 14:13:49 +000020#endif
21
Sadik Armagan1625efc2021-06-10 18:24:34 +010022#include <doctest/doctest.h>
23
telsoa014fcda012018-03-09 14:13:49 +000024using namespace armnn;
25
26namespace
27{
28
29using namespace std;
30
31template<typename IComputeTensorHandle>
Colm Donelan25ab3a82021-05-17 13:01:52 +010032PredicateResult CompareTensorHandleShape(IComputeTensorHandle* tensorHandle,
33 std::initializer_list<unsigned int> expectedDimensions)
telsoa014fcda012018-03-09 14:13:49 +000034{
35 arm_compute::ITensorInfo* info = tensorHandle->GetTensor().info();
36
37 auto infoNumDims = info->num_dimensions();
38 auto numExpectedDims = expectedDimensions.size();
39 if (infoNumDims != numExpectedDims)
40 {
Colm Donelan25ab3a82021-05-17 13:01:52 +010041 PredicateResult res(false);
42 res.Message() << "Different number of dimensions [" << info->num_dimensions()
telsoa014fcda012018-03-09 14:13:49 +000043 << "!=" << expectedDimensions.size() << "]";
44 return res;
45 }
46
47 size_t i = info->num_dimensions() - 1;
48
49 for (unsigned int expectedDimension : expectedDimensions)
50 {
51 if (info->dimension(i) != expectedDimension)
52 {
Colm Donelan25ab3a82021-05-17 13:01:52 +010053 PredicateResult res(false);
54 res.Message() << "For dimension " << i <<
Matthew Bentham89105282018-11-20 14:33:33 +000055 " expected size " << expectedDimension <<
56 " got " << info->dimension(i);
telsoa014fcda012018-03-09 14:13:49 +000057 return res;
58 }
59
60 i--;
61 }
62
Colm Donelan25ab3a82021-05-17 13:01:52 +010063 return PredicateResult(true);
telsoa014fcda012018-03-09 14:13:49 +000064}
65
telsoa01c577f2c2018-08-31 09:22:23 +010066template<typename IComputeTensorHandle>
telsoa014fcda012018-03-09 14:13:49 +000067void CreateMemCopyWorkloads(IWorkloadFactory& factory)
68{
Derek Lamberti84da38b2019-06-13 11:40:08 +010069 TensorHandleFactoryRegistry registry;
telsoa014fcda012018-03-09 14:13:49 +000070 Graph graph;
71 RefWorkloadFactory refFactory;
72
telsoa01c577f2c2018-08-31 09:22:23 +010073 // Creates the layers we're testing.
telsoa014fcda012018-03-09 14:13:49 +000074 Layer* const layer1 = graph.AddLayer<MemCopyLayer>("layer1");
75 Layer* const layer2 = graph.AddLayer<MemCopyLayer>("layer2");
76
telsoa01c577f2c2018-08-31 09:22:23 +010077 // Creates extra layers.
telsoa014fcda012018-03-09 14:13:49 +000078 Layer* const input = graph.AddLayer<InputLayer>(0, "input");
79 Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
80
telsoa01c577f2c2018-08-31 09:22:23 +010081 // Connects up.
telsoa014fcda012018-03-09 14:13:49 +000082 TensorInfo tensorInfo({2, 3}, DataType::Float32);
83 Connect(input, layer1, tensorInfo);
84 Connect(layer1, layer2, tensorInfo);
85 Connect(layer2, output, tensorInfo);
86
Derek Lamberti84da38b2019-06-13 11:40:08 +010087 input->CreateTensorHandles(registry, refFactory);
88 layer1->CreateTensorHandles(registry, factory);
89 layer2->CreateTensorHandles(registry, refFactory);
90 output->CreateTensorHandles(registry, refFactory);
telsoa014fcda012018-03-09 14:13:49 +000091
92 // make the workloads and check them
Derek Lamberti94a88d22019-12-10 21:12:59 +000093 auto workload1 = MakeAndCheckWorkload<CopyMemGenericWorkload>(*layer1, factory);
94 auto workload2 = MakeAndCheckWorkload<CopyMemGenericWorkload>(*layer2, refFactory);
telsoa014fcda012018-03-09 14:13:49 +000095
96 MemCopyQueueDescriptor queueDescriptor1 = workload1->GetData();
Sadik Armagan1625efc2021-06-10 18:24:34 +010097 CHECK(queueDescriptor1.m_Inputs.size() == 1);
98 CHECK(queueDescriptor1.m_Outputs.size() == 1);
Jan Eilersbb446e52020-04-02 13:56:54 +010099 auto inputHandle1 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor1.m_Inputs[0]);
100 auto outputHandle1 = PolymorphicDowncast<IComputeTensorHandle*>(queueDescriptor1.m_Outputs[0]);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100101 CHECK((inputHandle1->GetTensorInfo() == TensorInfo({2, 3}, DataType::Float32)));
Colm Donelan25ab3a82021-05-17 13:01:52 +0100102 auto result = CompareTensorHandleShape<IComputeTensorHandle>(outputHandle1, {2, 3});
Sadik Armagan1625efc2021-06-10 18:24:34 +0100103 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
telsoa014fcda012018-03-09 14:13:49 +0000104
105
106 MemCopyQueueDescriptor queueDescriptor2 = workload2->GetData();
Sadik Armagan1625efc2021-06-10 18:24:34 +0100107 CHECK(queueDescriptor2.m_Inputs.size() == 1);
108 CHECK(queueDescriptor2.m_Outputs.size() == 1);
Jan Eilersbb446e52020-04-02 13:56:54 +0100109 auto inputHandle2 = PolymorphicDowncast<IComputeTensorHandle*>(queueDescriptor2.m_Inputs[0]);
110 auto outputHandle2 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor2.m_Outputs[0]);
Colm Donelan25ab3a82021-05-17 13:01:52 +0100111 result = CompareTensorHandleShape<IComputeTensorHandle>(inputHandle2, {2, 3});
Sadik Armagan1625efc2021-06-10 18:24:34 +0100112 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
113 CHECK((outputHandle2->GetTensorInfo() == TensorInfo({2, 3}, DataType::Float32)));
telsoa014fcda012018-03-09 14:13:49 +0000114}
115
Matthew Bentham89105282018-11-20 14:33:33 +0000116} //namespace