blob: 1f542d24b4566731225bb4829a931e7eff5bc532 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +01002// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +01007#include <ResolveType.hpp>
Matteo Martincigh49124022019-01-11 13:25:59 +00008
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00009#include <armnn/backends/IBackendInternal.hpp>
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +000010
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000011#include <backendsCommon/test/LayerTests.hpp>
12#include <backendsCommon/test/TensorCopyUtils.hpp>
Aron Virginas-Tar56055192018-11-12 18:10:43 +000013#include <backendsCommon/test/WorkloadFactoryHelper.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014#include <backendsCommon/test/WorkloadTestUtils.hpp>
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010015
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016#include <test/TensorHelpers.hpp>
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010017
18#include <boost/multi_array.hpp>
19
20namespace
21{
22
Matteo Martincigh49124022019-01-11 13:25:59 +000023template<armnn::DataType dataType, typename T = armnn::ResolveType<dataType>>
24LayerTestResult<T, 4> MemCopyTest(armnn::IWorkloadFactory& srcWorkloadFactory,
25 armnn::IWorkloadFactory& dstWorkloadFactory,
26 bool withSubtensors)
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010027{
28 const std::array<unsigned int, 4> shapeData = { { 1u, 1u, 6u, 5u } };
29 const armnn::TensorShape tensorShape(4, shapeData.data());
Matteo Martincigh49124022019-01-11 13:25:59 +000030 const armnn::TensorInfo tensorInfo(tensorShape, dataType);
31 boost::multi_array<T, 4> inputData = MakeTensor<T, 4>(tensorInfo, std::vector<T>(
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010032 {
Matteo Martincigh49124022019-01-11 13:25:59 +000033 1, 2, 3, 4, 5,
34 6, 7, 8, 9, 10,
35 11, 12, 13, 14, 15,
36 16, 17, 18, 19, 20,
37 21, 22, 23, 24, 25,
38 26, 27, 28, 29, 30,
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010039 })
40 );
41
Matteo Martincigh49124022019-01-11 13:25:59 +000042 LayerTestResult<T, 4> ret(tensorInfo);
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010043 ret.outputExpected = inputData;
44
Matteo Martincigh49124022019-01-11 13:25:59 +000045 boost::multi_array<T, 4> outputData(shapeData);
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010046
Teresa Charlincf2d9132020-08-17 20:06:26 +010047 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010048 auto inputTensorHandle = srcWorkloadFactory.CreateTensorHandle(tensorInfo);
49 auto outputTensorHandle = dstWorkloadFactory.CreateTensorHandle(tensorInfo);
Teresa Charlincf2d9132020-08-17 20:06:26 +010050 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010051
52 AllocateAndCopyDataToITensorHandle(inputTensorHandle.get(), inputData.data());
53 outputTensorHandle->Allocate();
54
55 armnn::MemCopyQueueDescriptor memCopyQueueDesc;
56 armnn::WorkloadInfo workloadInfo;
57
58 const unsigned int origin[4] = {};
59
Teresa Charlincf2d9132020-08-17 20:06:26 +010060 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010061 auto workloadInput = (withSubtensors && srcWorkloadFactory.SupportsSubTensors())
62 ? srcWorkloadFactory.CreateSubTensorHandle(*inputTensorHandle, tensorShape, origin)
63 : std::move(inputTensorHandle);
64 auto workloadOutput = (withSubtensors && dstWorkloadFactory.SupportsSubTensors())
65 ? dstWorkloadFactory.CreateSubTensorHandle(*outputTensorHandle, tensorShape, origin)
66 : std::move(outputTensorHandle);
Teresa Charlincf2d9132020-08-17 20:06:26 +010067 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010068
69 AddInputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadInput.get());
70 AddOutputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadOutput.get());
71
72 dstWorkloadFactory.CreateMemCopy(memCopyQueueDesc, workloadInfo)->Execute();
73
74 CopyDataFromITensorHandle(outputData.data(), workloadOutput.get());
75 ret.output = outputData;
76
77 return ret;
78}
79
Matteo Martincigh49124022019-01-11 13:25:59 +000080template<typename SrcWorkloadFactory,
81 typename DstWorkloadFactory,
82 armnn::DataType dataType,
83 typename T = armnn::ResolveType<dataType>>
84LayerTestResult<T, 4> MemCopyTest(bool withSubtensors)
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010085{
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +000086 armnn::IBackendInternal::IMemoryManagerSharedPtr srcMemoryManager =
87 WorkloadFactoryHelper<SrcWorkloadFactory>::GetMemoryManager();
88
89 armnn::IBackendInternal::IMemoryManagerSharedPtr dstMemoryManager =
90 WorkloadFactoryHelper<DstWorkloadFactory>::GetMemoryManager();
91
92 SrcWorkloadFactory srcWorkloadFactory = WorkloadFactoryHelper<SrcWorkloadFactory>::GetFactory(srcMemoryManager);
93 DstWorkloadFactory dstWorkloadFactory = WorkloadFactoryHelper<DstWorkloadFactory>::GetFactory(dstMemoryManager);
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010094
Matteo Martincigh49124022019-01-11 13:25:59 +000095 return MemCopyTest<dataType>(srcWorkloadFactory, dstWorkloadFactory, withSubtensors);
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010096}
97
98} // anonymous namespace