blob: 91ba4eae1725a292219da7c3c133430ba0671843 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +01002// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +01007#include <ResolveType.hpp>
Matteo Martincigh49124022019-01-11 13:25:59 +00008
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00009#include <armnn/backends/IBackendInternal.hpp>
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +000010
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000011#include <backendsCommon/test/LayerTests.hpp>
12#include <backendsCommon/test/TensorCopyUtils.hpp>
Aron Virginas-Tar56055192018-11-12 18:10:43 +000013#include <backendsCommon/test/WorkloadFactoryHelper.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000014#include <backendsCommon/test/WorkloadTestUtils.hpp>
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010015
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000016#include <test/TensorHelpers.hpp>
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010017
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010018namespace
19{
20
Matteo Martincigh49124022019-01-11 13:25:59 +000021template<armnn::DataType dataType, typename T = armnn::ResolveType<dataType>>
22LayerTestResult<T, 4> MemCopyTest(armnn::IWorkloadFactory& srcWorkloadFactory,
23 armnn::IWorkloadFactory& dstWorkloadFactory,
24 bool withSubtensors)
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010025{
26 const std::array<unsigned int, 4> shapeData = { { 1u, 1u, 6u, 5u } };
27 const armnn::TensorShape tensorShape(4, shapeData.data());
Matteo Martincigh49124022019-01-11 13:25:59 +000028 const armnn::TensorInfo tensorInfo(tensorShape, dataType);
Sadik Armagan483c8112021-06-01 09:24:52 +010029 std::vector<T> inputData =
30 {
31 1, 2, 3, 4, 5,
32 6, 7, 8, 9, 10,
33 11, 12, 13, 14, 15,
34 16, 17, 18, 19, 20,
35 21, 22, 23, 24, 25,
36 26, 27, 28, 29, 30,
37 };
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010038
Matteo Martincigh49124022019-01-11 13:25:59 +000039 LayerTestResult<T, 4> ret(tensorInfo);
Sadik Armagan483c8112021-06-01 09:24:52 +010040 ret.m_ExpectedData = inputData;
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010041
Sadik Armagan483c8112021-06-01 09:24:52 +010042 std::vector<T> actualOutput(tensorInfo.GetNumElements());
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010043
Teresa Charlincf2d9132020-08-17 20:06:26 +010044 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010045 auto inputTensorHandle = srcWorkloadFactory.CreateTensorHandle(tensorInfo);
46 auto outputTensorHandle = dstWorkloadFactory.CreateTensorHandle(tensorInfo);
Teresa Charlincf2d9132020-08-17 20:06:26 +010047 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010048
49 AllocateAndCopyDataToITensorHandle(inputTensorHandle.get(), inputData.data());
50 outputTensorHandle->Allocate();
51
52 armnn::MemCopyQueueDescriptor memCopyQueueDesc;
53 armnn::WorkloadInfo workloadInfo;
54
55 const unsigned int origin[4] = {};
56
Teresa Charlincf2d9132020-08-17 20:06:26 +010057 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010058 auto workloadInput = (withSubtensors && srcWorkloadFactory.SupportsSubTensors())
59 ? srcWorkloadFactory.CreateSubTensorHandle(*inputTensorHandle, tensorShape, origin)
60 : std::move(inputTensorHandle);
61 auto workloadOutput = (withSubtensors && dstWorkloadFactory.SupportsSubTensors())
62 ? dstWorkloadFactory.CreateSubTensorHandle(*outputTensorHandle, tensorShape, origin)
63 : std::move(outputTensorHandle);
Teresa Charlincf2d9132020-08-17 20:06:26 +010064 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010065
66 AddInputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadInput.get());
67 AddOutputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadOutput.get());
68
69 dstWorkloadFactory.CreateMemCopy(memCopyQueueDesc, workloadInfo)->Execute();
70
Sadik Armagan483c8112021-06-01 09:24:52 +010071 CopyDataFromITensorHandle(actualOutput.data(), workloadOutput.get());
72 ret.m_ActualData = actualOutput;
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010073
74 return ret;
75}
76
Matteo Martincigh49124022019-01-11 13:25:59 +000077template<typename SrcWorkloadFactory,
78 typename DstWorkloadFactory,
79 armnn::DataType dataType,
80 typename T = armnn::ResolveType<dataType>>
81LayerTestResult<T, 4> MemCopyTest(bool withSubtensors)
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010082{
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +000083 armnn::IBackendInternal::IMemoryManagerSharedPtr srcMemoryManager =
84 WorkloadFactoryHelper<SrcWorkloadFactory>::GetMemoryManager();
85
86 armnn::IBackendInternal::IMemoryManagerSharedPtr dstMemoryManager =
87 WorkloadFactoryHelper<DstWorkloadFactory>::GetMemoryManager();
88
89 SrcWorkloadFactory srcWorkloadFactory = WorkloadFactoryHelper<SrcWorkloadFactory>::GetFactory(srcMemoryManager);
90 DstWorkloadFactory dstWorkloadFactory = WorkloadFactoryHelper<DstWorkloadFactory>::GetFactory(dstMemoryManager);
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010091
Matteo Martincigh49124022019-01-11 13:25:59 +000092 return MemCopyTest<dataType>(srcWorkloadFactory, dstWorkloadFactory, withSubtensors);
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010093}
94
95} // anonymous namespace