blob: 53d4dc9154e71c414e0f4ddde9352802f160de60 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00007#include <test/CreateWorkload.hpp>
arovir0143095f32018-10-09 18:04:24 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <backendsCommon/MemCopyWorkload.hpp>
10#include <reference/RefWorkloadFactory.hpp>
Matthew Bentham4cefc412019-06-18 16:14:34 +010011#include <reference/RefTensorHandle.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
Matteo Martincighd95e9062019-01-31 15:35:59 +000013#if defined(ARMCOMPUTECL_ENABLED)
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014#include <cl/ClTensorHandle.hpp>
telsoa014fcda012018-03-09 14:13:49 +000015#endif
16
Matteo Martincighd95e9062019-01-31 15:35:59 +000017#if defined(ARMCOMPUTENEON_ENABLED)
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000018#include <neon/NeonTensorHandle.hpp>
telsoa014fcda012018-03-09 14:13:49 +000019#endif
20
telsoa014fcda012018-03-09 14:13:49 +000021using namespace armnn;
22
23namespace
24{
25
26using namespace std;
27
28template<typename IComputeTensorHandle>
29boost::test_tools::predicate_result CompareTensorHandleShape(IComputeTensorHandle* tensorHandle,
30 std::initializer_list<unsigned int> expectedDimensions)
31{
32 arm_compute::ITensorInfo* info = tensorHandle->GetTensor().info();
33
34 auto infoNumDims = info->num_dimensions();
35 auto numExpectedDims = expectedDimensions.size();
36 if (infoNumDims != numExpectedDims)
37 {
38 boost::test_tools::predicate_result res(false);
39 res.message() << "Different number of dimensions [" << info->num_dimensions()
40 << "!=" << expectedDimensions.size() << "]";
41 return res;
42 }
43
44 size_t i = info->num_dimensions() - 1;
45
46 for (unsigned int expectedDimension : expectedDimensions)
47 {
48 if (info->dimension(i) != expectedDimension)
49 {
50 boost::test_tools::predicate_result res(false);
Matthew Bentham89105282018-11-20 14:33:33 +000051 res.message() << "For dimension " << i <<
52 " expected size " << expectedDimension <<
53 " got " << info->dimension(i);
telsoa014fcda012018-03-09 14:13:49 +000054 return res;
55 }
56
57 i--;
58 }
59
60 return true;
61}
62
telsoa01c577f2c2018-08-31 09:22:23 +010063template<typename IComputeTensorHandle>
telsoa014fcda012018-03-09 14:13:49 +000064void CreateMemCopyWorkloads(IWorkloadFactory& factory)
65{
Derek Lamberti84da38b2019-06-13 11:40:08 +010066 TensorHandleFactoryRegistry registry;
telsoa014fcda012018-03-09 14:13:49 +000067 Graph graph;
68 RefWorkloadFactory refFactory;
69
telsoa01c577f2c2018-08-31 09:22:23 +010070 // Creates the layers we're testing.
telsoa014fcda012018-03-09 14:13:49 +000071 Layer* const layer1 = graph.AddLayer<MemCopyLayer>("layer1");
72 Layer* const layer2 = graph.AddLayer<MemCopyLayer>("layer2");
73
telsoa01c577f2c2018-08-31 09:22:23 +010074 // Creates extra layers.
telsoa014fcda012018-03-09 14:13:49 +000075 Layer* const input = graph.AddLayer<InputLayer>(0, "input");
76 Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
77
telsoa01c577f2c2018-08-31 09:22:23 +010078 // Connects up.
telsoa014fcda012018-03-09 14:13:49 +000079 TensorInfo tensorInfo({2, 3}, DataType::Float32);
80 Connect(input, layer1, tensorInfo);
81 Connect(layer1, layer2, tensorInfo);
82 Connect(layer2, output, tensorInfo);
83
Derek Lamberti84da38b2019-06-13 11:40:08 +010084 input->CreateTensorHandles(registry, refFactory);
85 layer1->CreateTensorHandles(registry, factory);
86 layer2->CreateTensorHandles(registry, refFactory);
87 output->CreateTensorHandles(registry, refFactory);
telsoa014fcda012018-03-09 14:13:49 +000088
89 // make the workloads and check them
telsoa01c577f2c2018-08-31 09:22:23 +010090 auto workload1 = MakeAndCheckWorkload<CopyMemGenericWorkload>(*layer1, graph, factory);
91 auto workload2 = MakeAndCheckWorkload<CopyMemGenericWorkload>(*layer2, graph, refFactory);
telsoa014fcda012018-03-09 14:13:49 +000092
93 MemCopyQueueDescriptor queueDescriptor1 = workload1->GetData();
94 BOOST_TEST(queueDescriptor1.m_Inputs.size() == 1);
95 BOOST_TEST(queueDescriptor1.m_Outputs.size() == 1);
Matthew Bentham4cefc412019-06-18 16:14:34 +010096 auto inputHandle1 = boost::polymorphic_downcast<RefTensorHandle*>(queueDescriptor1.m_Inputs[0]);
telsoa014fcda012018-03-09 14:13:49 +000097 auto outputHandle1 = boost::polymorphic_downcast<IComputeTensorHandle*>(queueDescriptor1.m_Outputs[0]);
98 BOOST_TEST((inputHandle1->GetTensorInfo() == TensorInfo({2, 3}, DataType::Float32)));
99 BOOST_TEST(CompareTensorHandleShape<IComputeTensorHandle>(outputHandle1, {2, 3}));
100
101
102 MemCopyQueueDescriptor queueDescriptor2 = workload2->GetData();
103 BOOST_TEST(queueDescriptor2.m_Inputs.size() == 1);
104 BOOST_TEST(queueDescriptor2.m_Outputs.size() == 1);
105 auto inputHandle2 = boost::polymorphic_downcast<IComputeTensorHandle*>(queueDescriptor2.m_Inputs[0]);
Matthew Bentham4cefc412019-06-18 16:14:34 +0100106 auto outputHandle2 = boost::polymorphic_downcast<RefTensorHandle*>(queueDescriptor2.m_Outputs[0]);
telsoa014fcda012018-03-09 14:13:49 +0000107 BOOST_TEST(CompareTensorHandleShape<IComputeTensorHandle>(inputHandle2, {2, 3}));
108 BOOST_TEST((outputHandle2->GetTensorInfo() == TensorInfo({2, 3}, DataType::Float32)));
109}
110
Matthew Bentham89105282018-11-20 14:33:33 +0000111} //namespace