blob: a2076722960332938cb959bcc5ad4c3f9799b619 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +01002// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +01007#include <ResolveType.hpp>
Matteo Martincigh49124022019-01-11 13:25:59 +00008
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00009#include <armnn/backends/IBackendInternal.hpp>
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +000010
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000011#include <backendsCommon/test/LayerTests.hpp>
12#include <backendsCommon/test/TensorCopyUtils.hpp>
Aron Virginas-Tar56055192018-11-12 18:10:43 +000013#include <backendsCommon/test/WorkloadFactoryHelper.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014#include <backendsCommon/test/WorkloadTestUtils.hpp>
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010015
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016#include <test/TensorHelpers.hpp>
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010017
18#include <boost/multi_array.hpp>
19
20namespace
21{
22
Matteo Martincigh49124022019-01-11 13:25:59 +000023template<armnn::DataType dataType, typename T = armnn::ResolveType<dataType>>
24LayerTestResult<T, 4> MemCopyTest(armnn::IWorkloadFactory& srcWorkloadFactory,
25 armnn::IWorkloadFactory& dstWorkloadFactory,
26 bool withSubtensors)
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010027{
28 const std::array<unsigned int, 4> shapeData = { { 1u, 1u, 6u, 5u } };
29 const armnn::TensorShape tensorShape(4, shapeData.data());
Matteo Martincigh49124022019-01-11 13:25:59 +000030 const armnn::TensorInfo tensorInfo(tensorShape, dataType);
31 boost::multi_array<T, 4> inputData = MakeTensor<T, 4>(tensorInfo, std::vector<T>(
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010032 {
Matteo Martincigh49124022019-01-11 13:25:59 +000033 1, 2, 3, 4, 5,
34 6, 7, 8, 9, 10,
35 11, 12, 13, 14, 15,
36 16, 17, 18, 19, 20,
37 21, 22, 23, 24, 25,
38 26, 27, 28, 29, 30,
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010039 })
40 );
41
Matteo Martincigh49124022019-01-11 13:25:59 +000042 LayerTestResult<T, 4> ret(tensorInfo);
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010043 ret.outputExpected = inputData;
44
Matteo Martincigh49124022019-01-11 13:25:59 +000045 boost::multi_array<T, 4> outputData(shapeData);
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010046
47 auto inputTensorHandle = srcWorkloadFactory.CreateTensorHandle(tensorInfo);
48 auto outputTensorHandle = dstWorkloadFactory.CreateTensorHandle(tensorInfo);
49
50 AllocateAndCopyDataToITensorHandle(inputTensorHandle.get(), inputData.data());
51 outputTensorHandle->Allocate();
52
53 armnn::MemCopyQueueDescriptor memCopyQueueDesc;
54 armnn::WorkloadInfo workloadInfo;
55
56 const unsigned int origin[4] = {};
57
58 auto workloadInput = (withSubtensors && srcWorkloadFactory.SupportsSubTensors())
59 ? srcWorkloadFactory.CreateSubTensorHandle(*inputTensorHandle, tensorShape, origin)
60 : std::move(inputTensorHandle);
61 auto workloadOutput = (withSubtensors && dstWorkloadFactory.SupportsSubTensors())
62 ? dstWorkloadFactory.CreateSubTensorHandle(*outputTensorHandle, tensorShape, origin)
63 : std::move(outputTensorHandle);
64
65 AddInputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadInput.get());
66 AddOutputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadOutput.get());
67
68 dstWorkloadFactory.CreateMemCopy(memCopyQueueDesc, workloadInfo)->Execute();
69
70 CopyDataFromITensorHandle(outputData.data(), workloadOutput.get());
71 ret.output = outputData;
72
73 return ret;
74}
75
Matteo Martincigh49124022019-01-11 13:25:59 +000076template<typename SrcWorkloadFactory,
77 typename DstWorkloadFactory,
78 armnn::DataType dataType,
79 typename T = armnn::ResolveType<dataType>>
80LayerTestResult<T, 4> MemCopyTest(bool withSubtensors)
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010081{
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +000082 armnn::IBackendInternal::IMemoryManagerSharedPtr srcMemoryManager =
83 WorkloadFactoryHelper<SrcWorkloadFactory>::GetMemoryManager();
84
85 armnn::IBackendInternal::IMemoryManagerSharedPtr dstMemoryManager =
86 WorkloadFactoryHelper<DstWorkloadFactory>::GetMemoryManager();
87
88 SrcWorkloadFactory srcWorkloadFactory = WorkloadFactoryHelper<SrcWorkloadFactory>::GetFactory(srcMemoryManager);
89 DstWorkloadFactory dstWorkloadFactory = WorkloadFactoryHelper<DstWorkloadFactory>::GetFactory(dstMemoryManager);
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010090
Matteo Martincigh49124022019-01-11 13:25:59 +000091 return MemCopyTest<dataType>(srcWorkloadFactory, dstWorkloadFactory, withSubtensors);
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010092}
93
94} // anonymous namespace