blob: 956ea27c154794438d9becefa556d333cc356506 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +01007#include <ResolveType.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00008#include <armnn/backends/IBackendInternal.hpp>
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +00009
Sadik Armagana097d2a2021-11-24 15:47:28 +000010#include <test/TensorHelpers.hpp>
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010011
Sadik Armagana097d2a2021-11-24 15:47:28 +000012#include <armnnTestUtils/LayerTestResult.hpp>
13#include <armnnTestUtils/TensorCopyUtils.hpp>
Colm Donelan0c479742021-12-10 12:43:54 +000014#include <armnnTestUtils/WorkloadTestUtils.hpp>
Sadik Armagana097d2a2021-11-24 15:47:28 +000015#include <backendsCommon/test/WorkloadFactoryHelper.hpp>
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010016
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010017namespace
18{
19
Matteo Martincigh49124022019-01-11 13:25:59 +000020template<armnn::DataType dataType, typename T = armnn::ResolveType<dataType>>
21LayerTestResult<T, 4> MemCopyTest(armnn::IWorkloadFactory& srcWorkloadFactory,
22 armnn::IWorkloadFactory& dstWorkloadFactory,
23 bool withSubtensors)
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010024{
25 const std::array<unsigned int, 4> shapeData = { { 1u, 1u, 6u, 5u } };
26 const armnn::TensorShape tensorShape(4, shapeData.data());
Matteo Martincigh49124022019-01-11 13:25:59 +000027 const armnn::TensorInfo tensorInfo(tensorShape, dataType);
Sadik Armagan483c8112021-06-01 09:24:52 +010028 std::vector<T> inputData =
29 {
30 1, 2, 3, 4, 5,
31 6, 7, 8, 9, 10,
32 11, 12, 13, 14, 15,
33 16, 17, 18, 19, 20,
34 21, 22, 23, 24, 25,
35 26, 27, 28, 29, 30,
36 };
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010037
Matteo Martincigh49124022019-01-11 13:25:59 +000038 LayerTestResult<T, 4> ret(tensorInfo);
Sadik Armagan483c8112021-06-01 09:24:52 +010039 ret.m_ExpectedData = inputData;
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010040
Sadik Armagan483c8112021-06-01 09:24:52 +010041 std::vector<T> actualOutput(tensorInfo.GetNumElements());
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010042
Teresa Charlincf2d9132020-08-17 20:06:26 +010043 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010044 auto inputTensorHandle = srcWorkloadFactory.CreateTensorHandle(tensorInfo);
45 auto outputTensorHandle = dstWorkloadFactory.CreateTensorHandle(tensorInfo);
Teresa Charlincf2d9132020-08-17 20:06:26 +010046 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010047
48 AllocateAndCopyDataToITensorHandle(inputTensorHandle.get(), inputData.data());
49 outputTensorHandle->Allocate();
50
51 armnn::MemCopyQueueDescriptor memCopyQueueDesc;
52 armnn::WorkloadInfo workloadInfo;
53
54 const unsigned int origin[4] = {};
55
Teresa Charlincf2d9132020-08-17 20:06:26 +010056 ARMNN_NO_DEPRECATE_WARN_BEGIN
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010057 auto workloadInput = (withSubtensors && srcWorkloadFactory.SupportsSubTensors())
58 ? srcWorkloadFactory.CreateSubTensorHandle(*inputTensorHandle, tensorShape, origin)
59 : std::move(inputTensorHandle);
60 auto workloadOutput = (withSubtensors && dstWorkloadFactory.SupportsSubTensors())
61 ? dstWorkloadFactory.CreateSubTensorHandle(*outputTensorHandle, tensorShape, origin)
62 : std::move(outputTensorHandle);
Teresa Charlincf2d9132020-08-17 20:06:26 +010063 ARMNN_NO_DEPRECATE_WARN_END
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010064
65 AddInputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadInput.get());
66 AddOutputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadOutput.get());
67
Teresa Charlin611c7fb2022-01-07 09:47:29 +000068 dstWorkloadFactory.CreateWorkload(armnn::LayerType::MemCopy, memCopyQueueDesc, workloadInfo)->Execute();
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010069
Sadik Armagan483c8112021-06-01 09:24:52 +010070 CopyDataFromITensorHandle(actualOutput.data(), workloadOutput.get());
71 ret.m_ActualData = actualOutput;
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010072
73 return ret;
74}
75
Matteo Martincigh49124022019-01-11 13:25:59 +000076template<typename SrcWorkloadFactory,
77 typename DstWorkloadFactory,
78 armnn::DataType dataType,
79 typename T = armnn::ResolveType<dataType>>
80LayerTestResult<T, 4> MemCopyTest(bool withSubtensors)
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010081{
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +000082 armnn::IBackendInternal::IMemoryManagerSharedPtr srcMemoryManager =
83 WorkloadFactoryHelper<SrcWorkloadFactory>::GetMemoryManager();
84
85 armnn::IBackendInternal::IMemoryManagerSharedPtr dstMemoryManager =
86 WorkloadFactoryHelper<DstWorkloadFactory>::GetMemoryManager();
87
88 SrcWorkloadFactory srcWorkloadFactory = WorkloadFactoryHelper<SrcWorkloadFactory>::GetFactory(srcMemoryManager);
89 DstWorkloadFactory dstWorkloadFactory = WorkloadFactoryHelper<DstWorkloadFactory>::GetFactory(dstMemoryManager);
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010090
Matteo Martincigh49124022019-01-11 13:25:59 +000091 return MemCopyTest<dataType>(srcWorkloadFactory, dstWorkloadFactory, withSubtensors);
Aron Virginas-Tar3b278e92018-10-12 13:00:55 +010092}
93
94} // anonymous namespace