blob: 1856dcb056d0fc320078718f878f85f98a199bed [file] [log] [blame]
Colm Donelanc42a9872022-02-02 16:35:09 +00001//
2// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "LayerTestResult.hpp"
8#include "TensorCopyUtils.hpp"
9#include "TensorHelpers.hpp"
10#include "WorkloadTestUtils.hpp"
11#include <ResolveType.hpp>
12#include <armnn/backends/IBackendInternal.hpp>
13#include <armnnTestUtils/MockBackend.hpp>
14
15namespace
16{
17
18template<armnn::DataType dataType, typename T = armnn::ResolveType<dataType>>
19LayerTestResult<T, 4> MemCopyTest(armnn::IWorkloadFactory& srcWorkloadFactory,
20 armnn::IWorkloadFactory& dstWorkloadFactory,
21 bool withSubtensors)
22{
23 const std::array<unsigned int, 4> shapeData = { { 1u, 1u, 6u, 5u } };
24 const armnn::TensorShape tensorShape(4, shapeData.data());
25 const armnn::TensorInfo tensorInfo(tensorShape, dataType);
26 std::vector<T> inputData =
27 {
28 1, 2, 3, 4, 5,
29 6, 7, 8, 9, 10,
30 11, 12, 13, 14, 15,
31 16, 17, 18, 19, 20,
32 21, 22, 23, 24, 25,
33 26, 27, 28, 29, 30,
34 };
35
36 LayerTestResult<T, 4> ret(tensorInfo);
37 ret.m_ExpectedData = inputData;
38
39 std::vector<T> actualOutput(tensorInfo.GetNumElements());
40
41 ARMNN_NO_DEPRECATE_WARN_BEGIN
42 auto inputTensorHandle = srcWorkloadFactory.CreateTensorHandle(tensorInfo);
43 auto outputTensorHandle = dstWorkloadFactory.CreateTensorHandle(tensorInfo);
44 ARMNN_NO_DEPRECATE_WARN_END
45
46 AllocateAndCopyDataToITensorHandle(inputTensorHandle.get(), inputData.data());
47 outputTensorHandle->Allocate();
48
49 armnn::MemCopyQueueDescriptor memCopyQueueDesc;
50 armnn::WorkloadInfo workloadInfo;
51
52 const unsigned int origin[4] = {};
53
54 ARMNN_NO_DEPRECATE_WARN_BEGIN
55 auto workloadInput = (withSubtensors && srcWorkloadFactory.SupportsSubTensors())
56 ? srcWorkloadFactory.CreateSubTensorHandle(*inputTensorHandle, tensorShape, origin)
57 : std::move(inputTensorHandle);
58 auto workloadOutput = (withSubtensors && dstWorkloadFactory.SupportsSubTensors())
59 ? dstWorkloadFactory.CreateSubTensorHandle(*outputTensorHandle, tensorShape, origin)
60 : std::move(outputTensorHandle);
61 ARMNN_NO_DEPRECATE_WARN_END
62
63 AddInputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadInput.get());
64 AddOutputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadOutput.get());
65
66 dstWorkloadFactory.CreateWorkload(armnn::LayerType::MemCopy, memCopyQueueDesc, workloadInfo)->Execute();
67
68 CopyDataFromITensorHandle(actualOutput.data(), workloadOutput.get());
69 ret.m_ActualData = actualOutput;
70
71 return ret;
72}
73
/// Primary template, deliberately left without members.
/// Every workload factory type used with the factory-type-driven MemCopyTest overload must provide
/// a specialisation exposing static GetMemoryManager() and GetFactory(memoryManager) functions.
template <class WorkloadFactoryType>
struct MemCopyTestHelper
{
};
77template <>
78struct MemCopyTestHelper<armnn::MockWorkloadFactory>
79{
80 static armnn::IBackendInternal::IMemoryManagerSharedPtr GetMemoryManager()
81 {
82 armnn::MockBackend backend;
83 return backend.CreateMemoryManager();
84 }
85
86 static armnn::MockWorkloadFactory
87 GetFactory(const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager = nullptr)
88 {
89 IgnoreUnused(memoryManager);
90 return armnn::MockWorkloadFactory();
91 }
92};
93
94using MockMemCopyTestHelper = MemCopyTestHelper<armnn::MockWorkloadFactory>;
95
96template <typename SrcWorkloadFactory,
97 typename DstWorkloadFactory,
98 armnn::DataType dataType,
99 typename T = armnn::ResolveType<dataType>>
100LayerTestResult<T, 4> MemCopyTest(bool withSubtensors)
101{
102
103 armnn::IBackendInternal::IMemoryManagerSharedPtr srcMemoryManager =
104 MemCopyTestHelper<SrcWorkloadFactory>::GetMemoryManager();
105
106 armnn::IBackendInternal::IMemoryManagerSharedPtr dstMemoryManager =
107 MemCopyTestHelper<DstWorkloadFactory>::GetMemoryManager();
108
109 SrcWorkloadFactory srcWorkloadFactory = MemCopyTestHelper<SrcWorkloadFactory>::GetFactory(srcMemoryManager);
110 DstWorkloadFactory dstWorkloadFactory = MemCopyTestHelper<DstWorkloadFactory>::GetFactory(dstMemoryManager);
111
112 return MemCopyTest<dataType>(srcWorkloadFactory, dstWorkloadFactory, withSubtensors);
113}
114
115} // anonymous namespace