//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <atomic>

#include <armnn/backends/IBackendInternal.hpp>
#include <armnn/backends/MemCopyWorkload.hpp>
#include <armnnTestUtils/MockTensorHandle.hpp>
#include <backendsCommon/LayerSupportBase.hpp>

namespace armnn
{
// A bare bones Mock backend to enable unit testing of simple tensor manipulation features.
class MockBackend : public IBackendInternal
{
public:
    MockBackend() = default;

    ~MockBackend() = default;

    /// Returns the unique BackendId shared by all MockBackend instances.
    static const BackendId& GetIdStatic();

    /// Instance identifier; simply forwards to the static id.
    const BackendId& GetId() const override
    {
        return GetIdStatic();
    }

    /// Creates the factory that produces mock workloads for this backend.
    IBackendInternal::IWorkloadFactoryPtr
        CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr& memoryManager = nullptr) const override;

    /// Layer-support queries are answered by MockLayerSupport (declared below in this file).
    IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override;

    IBackendInternal::IMemoryManagerUniquePtr CreateMemoryManager() const override;

    IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override;

    /// Returns a profiling context (see MockBackendProfilingContext below) used by profiling unit tests.
    IBackendInternal::IBackendProfilingContextPtr
        CreateBackendProfilingContext(const IRuntime::CreationOptions& creationOptions,
                                      IBackendProfilingPtr& backendProfiling) override;

    OptimizationViews OptimizeSubgraphView(const SubgraphView& subgraph) const override;

    std::unique_ptr<ICustomAllocator> GetDefaultAllocator() const override;
};
47
48class MockWorkloadFactory : public IWorkloadFactory
49{
50
51public:
52 explicit MockWorkloadFactory(const std::shared_ptr<MockMemoryManager>& memoryManager);
53 MockWorkloadFactory();
54
55 ~MockWorkloadFactory()
56 {}
57
58 const BackendId& GetBackendId() const override;
59
60 bool SupportsSubTensors() const override
61 {
62 return false;
63 }
64
65 ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateSubTensorHandle instead")
66 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle&,
67 TensorShape const&,
68 unsigned int const*) const override
69 {
70 return nullptr;
71 }
72
73 ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
74 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
75 const bool IsMemoryManaged = true) const override
76 {
77 IgnoreUnused(IsMemoryManaged);
78 return std::make_unique<MockTensorHandle>(tensorInfo, m_MemoryManager);
79 };
80
81 ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
82 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
83 DataLayout dataLayout,
84 const bool IsMemoryManaged = true) const override
85 {
86 IgnoreUnused(dataLayout, IsMemoryManaged);
87 return std::make_unique<MockTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
88 };
89
90 ARMNN_DEPRECATED_MSG_REMOVAL_DATE(
91 "Use ABI stable "
92 "CreateWorkload(LayerType, const QueueDescriptor&, const WorkloadInfo& info) instead.",
93 "22.11")
94 std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor,
95 const WorkloadInfo& info) const override
96 {
97 if (info.m_InputTensorInfos.empty())
98 {
99 throw InvalidArgumentException("MockWorkloadFactory::CreateInput: Input cannot be zero length");
100 }
101 if (info.m_OutputTensorInfos.empty())
102 {
103 throw InvalidArgumentException("MockWorkloadFactory::CreateInput: Output cannot be zero length");
104 }
105
106 if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
107 {
108 throw InvalidArgumentException(
109 "MockWorkloadFactory::CreateInput: data input and output differ in byte count.");
110 }
111
112 return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
113 };
114
115 std::unique_ptr<IWorkload>
116 CreateWorkload(LayerType type, const QueueDescriptor& descriptor, const WorkloadInfo& info) const override;
117
118private:
119 mutable std::shared_ptr<MockMemoryManager> m_MemoryManager;
120};
121
// RAII helper: registers the mock backend with the BackendRegistry on
// construction and deregisters it on destruction (definitions live in the .cpp).
class MockBackendInitialiser
{
public:
    MockBackendInitialiser();
    ~MockBackendInitialiser();
};
128
129class MockBackendProfilingContext : public arm::pipe::IBackendProfilingContext
130{
131public:
132 MockBackendProfilingContext(IBackendInternal::IBackendProfilingPtr& backendProfiling)
133 : m_BackendProfiling(std::move(backendProfiling))
134 , m_CapturePeriod(0)
135 , m_IsTimelineEnabled(true)
136 {}
137
138 ~MockBackendProfilingContext() = default;
139
140 IBackendInternal::IBackendProfilingPtr& GetBackendProfiling()
141 {
142 return m_BackendProfiling;
143 }
144
145 uint16_t RegisterCounters(uint16_t currentMaxGlobalCounterId)
146 {
147 std::unique_ptr<arm::pipe::IRegisterBackendCounters> counterRegistrar =
148 m_BackendProfiling->GetCounterRegistrationInterface(static_cast<uint16_t>(currentMaxGlobalCounterId));
149
150 std::string categoryName("MockCounters");
151 counterRegistrar->RegisterCategory(categoryName);
152
153 counterRegistrar->RegisterCounter(0, categoryName, 0, 0, 1.f, "Mock Counter One", "Some notional counter");
154
155 counterRegistrar->RegisterCounter(1, categoryName, 0, 0, 1.f, "Mock Counter Two",
156 "Another notional counter");
157
158 std::string units("microseconds");
159 uint16_t nextMaxGlobalCounterId =
160 counterRegistrar->RegisterCounter(2, categoryName, 0, 0, 1.f, "Mock MultiCore Counter",
161 "A dummy four core counter", units, 4);
162 return nextMaxGlobalCounterId;
163 }
164
Jim Flynndecd08b2022-03-13 22:35:46 +0000165 arm::pipe::Optional<std::string> ActivateCounters(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds)
Cathal Corbett3464ba12022-03-04 11:36:39 +0000166 {
167 if (capturePeriod == 0 || counterIds.size() == 0)
168 {
169 m_ActiveCounters.clear();
170 }
171 else if (capturePeriod == 15939u)
172 {
Jim Flynndecd08b2022-03-13 22:35:46 +0000173 return arm::pipe::Optional<std::string>("ActivateCounters example test error");
Cathal Corbett3464ba12022-03-04 11:36:39 +0000174 }
175 m_CapturePeriod = capturePeriod;
176 m_ActiveCounters = counterIds;
Jim Flynndecd08b2022-03-13 22:35:46 +0000177 return arm::pipe::Optional<std::string>();
Cathal Corbett3464ba12022-03-04 11:36:39 +0000178 }
179
180 std::vector<arm::pipe::Timestamp> ReportCounterValues()
181 {
182 std::vector<arm::pipe::CounterValue> counterValues;
183
184 for (auto counterId : m_ActiveCounters)
185 {
186 counterValues.emplace_back(arm::pipe::CounterValue{ counterId, counterId + 1u });
187 }
188
189 uint64_t timestamp = m_CapturePeriod;
190 return { arm::pipe::Timestamp{ timestamp, counterValues } };
191 }
192
193 bool EnableProfiling(bool)
194 {
195 auto sendTimelinePacket = m_BackendProfiling->GetSendTimelinePacket();
196 sendTimelinePacket->SendTimelineEntityBinaryPacket(4256);
197 sendTimelinePacket->Commit();
198 return true;
199 }
200
201 bool EnableTimelineReporting(bool isEnabled)
202 {
203 m_IsTimelineEnabled = isEnabled;
204 return isEnabled;
205 }
206
207 bool TimelineReportingEnabled()
208 {
209 return m_IsTimelineEnabled;
210 }
211
212private:
213 IBackendInternal::IBackendProfilingPtr m_BackendProfiling;
214 uint32_t m_CapturePeriod;
215 std::vector<uint16_t> m_ActiveCounters;
216 std::atomic<bool> m_IsTimelineEnabled;
217};
218
219class MockBackendProfilingService
220{
221public:
222 // Getter for the singleton instance
223 static MockBackendProfilingService& Instance()
224 {
225 static MockBackendProfilingService instance;
226 return instance;
227 }
228
229 MockBackendProfilingContext* GetContext()
230 {
231 return m_sharedContext.get();
232 }
233
234 void SetProfilingContextPtr(std::shared_ptr<MockBackendProfilingContext> shared)
235 {
236 m_sharedContext = shared;
237 }
238
239private:
240 std::shared_ptr<MockBackendProfilingContext> m_sharedContext;
241};
242
243class MockLayerSupport : public LayerSupportBase
244{
245public:
246 bool IsLayerSupported(const LayerType& type,
247 const std::vector<TensorInfo>& infos,
248 const BaseDescriptor& descriptor,
249 const Optional<LstmInputParamsInfo>& /*lstmParamsInfo*/,
250 const Optional<QuantizedLstmInputParamsInfo>& /*quantizedLstmParamsInfo*/,
251 Optional<std::string&> reasonIfUnsupported) const override
252 {
253 switch(type)
254 {
255 case LayerType::Input:
256 return IsInputSupported(infos[0], reasonIfUnsupported);
257 case LayerType::Output:
258 return IsOutputSupported(infos[0], reasonIfUnsupported);
259 case LayerType::Addition:
260 return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
261 case LayerType::Convolution2d:
262 {
263 if (infos.size() != 4)
264 {
265 throw InvalidArgumentException("Invalid number of TransposeConvolution2d "
266 "TensorInfos. TensorInfos should be of format: "
267 "{input, output, weights, biases}.");
268 }
269
270 auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
271 if (infos[3] == TensorInfo())
272 {
273 return IsConvolution2dSupported(infos[0],
274 infos[1],
275 desc,
276 infos[2],
277 EmptyOptional(),
278 reasonIfUnsupported);
279 }
280 else
281 {
282 return IsConvolution2dSupported(infos[0],
283 infos[1],
284 desc,
285 infos[2],
286 infos[3],
287 reasonIfUnsupported);
288 }
289 }
290 default:
291 return false;
292 }
293 }
294
295 bool IsInputSupported(const TensorInfo& /*input*/,
296 Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
297 {
298 return true;
299 }
300
301 bool IsOutputSupported(const TensorInfo& /*input*/,
302 Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
303 {
304 return true;
305 }
306
307 bool IsAdditionSupported(const TensorInfo& /*input0*/,
308 const TensorInfo& /*input1*/,
309 const TensorInfo& /*output*/,
310 Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
311 {
312 return true;
313 }
314
315 bool IsConvolution2dSupported(const TensorInfo& /*input*/,
316 const TensorInfo& /*output*/,
317 const Convolution2dDescriptor& /*descriptor*/,
318 const TensorInfo& /*weights*/,
319 const Optional<TensorInfo>& /*biases*/,
320 Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
321 {
322 return true;
323 }
324};
325
} // namespace armnn