blob: 8bc41b3f3f57ff6f530f5771537685ea2706a38e [file] [log] [blame]
Colm Donelan17948b52022-02-01 23:37:04 +00001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include <armnn/backends/IBackendInternal.hpp>
8#include <armnn/backends/MemCopyWorkload.hpp>
9#include <armnnTestUtils/MockTensorHandle.hpp>
10
11namespace armnn
12{
13
14// A bare bones Mock backend to enable unit testing of simple tensor manipulation features.
15class MockBackend : public IBackendInternal
16{
17public:
18 MockBackend() = default;
19
20 ~MockBackend() = default;
21
22 static const BackendId& GetIdStatic();
23
24 const BackendId& GetId() const override
25 {
26 return GetIdStatic();
27 }
28 IBackendInternal::IWorkloadFactoryPtr
29 CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr& memoryManager = nullptr) const override
30 {
31 IgnoreUnused(memoryManager);
32 return nullptr;
33 }
34
35 IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
36 {
37 return nullptr;
38 };
39};
40
41class MockWorkloadFactory : public IWorkloadFactory
42{
43
44public:
45 explicit MockWorkloadFactory(const std::shared_ptr<MockMemoryManager>& memoryManager);
46 MockWorkloadFactory();
47
48 ~MockWorkloadFactory()
49 {}
50
51 const BackendId& GetBackendId() const override;
52
53 bool SupportsSubTensors() const override
54 {
55 return false;
56 }
57
58 ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateSubTensorHandle instead")
59 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle&,
60 TensorShape const&,
61 unsigned int const*) const override
62 {
63 return nullptr;
64 }
65
66 ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
67 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
68 const bool IsMemoryManaged = true) const override
69 {
70 IgnoreUnused(IsMemoryManaged);
71 return std::make_unique<MockTensorHandle>(tensorInfo, m_MemoryManager);
72 };
73
74 ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
75 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
76 DataLayout dataLayout,
77 const bool IsMemoryManaged = true) const override
78 {
79 IgnoreUnused(dataLayout, IsMemoryManaged);
80 return std::make_unique<MockTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
81 };
82
83 ARMNN_DEPRECATED_MSG_REMOVAL_DATE(
84 "Use ABI stable "
85 "CreateWorkload(LayerType, const QueueDescriptor&, const WorkloadInfo& info) instead.",
86 "22.11")
87 std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor,
88 const WorkloadInfo& info) const override
89 {
90 if (info.m_InputTensorInfos.empty())
91 {
92 throw InvalidArgumentException("MockWorkloadFactory::CreateInput: Input cannot be zero length");
93 }
94 if (info.m_OutputTensorInfos.empty())
95 {
96 throw InvalidArgumentException("MockWorkloadFactory::CreateInput: Output cannot be zero length");
97 }
98
99 if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
100 {
101 throw InvalidArgumentException(
102 "MockWorkloadFactory::CreateInput: data input and output differ in byte count.");
103 }
104
105 return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
106 };
107
108 std::unique_ptr<IWorkload>
109 CreateWorkload(LayerType type, const QueueDescriptor& descriptor, const WorkloadInfo& info) const override;
110
111private:
112 mutable std::shared_ptr<MockMemoryManager> m_MemoryManager;
113};
114
115} // namespace armnn