//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/Deprecated.hpp>
#include <armnn/Descriptors.hpp>
#include <armnn/Exceptions.hpp>
#include <armnn/IRuntime.hpp>
#include <armnn/Optional.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>
#include <armnn/backends/IBackendInternal.hpp>
#include <armnn/backends/MemCopyWorkload.hpp>
#include <armnn/backends/ITensorHandle.hpp>
#include <armnn/backends/IWorkload.hpp>
#include <armnn/backends/OptimizationViews.hpp>
#include <armnn/backends/SubgraphView.hpp>
#include <armnn/backends/WorkloadData.hpp>
#include <armnn/backends/WorkloadFactory.hpp>
#include <armnn/backends/WorkloadInfo.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>
#include <armnnTestUtils/MockTensorHandle.hpp>
#include <backendsCommon/LayerSupportBase.hpp>

#include <client/include/CounterValue.hpp>
#include <client/include/ISendTimelinePacket.hpp>
#include <client/include/Timestamp.hpp>
#include <client/include/backends/IBackendProfiling.hpp>
#include <client/include/backends/IBackendProfilingContext.hpp>
#include <common/include/Optional.hpp>

#include <atomic>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace armnn
{
class BackendId;
class ICustomAllocator;
class MockMemoryManager;
struct LstmInputParamsInfo;
struct QuantizedLstmInputParamsInfo;

// A bare-bones mock backend to enable unit testing of simple tensor manipulation features.
class MockBackend : public IBackendInternal
{
public:
    MockBackend() = default;

    ~MockBackend() = default;

    static const BackendId& GetIdStatic();

    const BackendId& GetId() const override
    {
        return GetIdStatic();
    }
    IBackendInternal::IWorkloadFactoryPtr
        CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr& memoryManager = nullptr) const override;

    IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override;

    IBackendInternal::IMemoryManagerUniquePtr CreateMemoryManager() const override;

    IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override;
    IBackendInternal::IBackendProfilingContextPtr
        CreateBackendProfilingContext(const IRuntime::CreationOptions& creationOptions,
                                      IBackendProfilingPtr& backendProfiling) override;

    OptimizationViews OptimizeSubgraphView(const SubgraphView& subgraph) const override;

    std::unique_ptr<ICustomAllocator> GetDefaultAllocator() const override;
};
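
// Illustrative usage sketch (added for documentation; a minimal example, not part of the original
// header and not taken from an existing Arm NN test). It exercises only the MockBackend methods
// declared above and assumes the MockBackend implementation is linked in, as in any test that
// uses this header. The function name is hypothetical.
inline void MockBackendUsageSketch()
{
    MockBackend backend;

    // Create a memory manager and hand it to the workload factory, mirroring what a runtime
    // would do. CreateWorkloadFactory also accepts its default nullptr argument.
    IBackendInternal::IMemoryManagerSharedPtr memoryManager = backend.CreateMemoryManager();
    IBackendInternal::IWorkloadFactoryPtr workloadFactory = backend.CreateWorkloadFactory(memoryManager);

    // Layer support queries are answered by MockLayerSupport, defined later in this file.
    IBackendInternal::ILayerSupportSharedPtr layerSupport = backend.GetLayerSupport();

    IgnoreUnused(workloadFactory, layerSupport);
}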

class MockWorkloadFactory : public IWorkloadFactory
{

public:
    explicit MockWorkloadFactory(const std::shared_ptr<MockMemoryManager>& memoryManager);
    MockWorkloadFactory();

    ~MockWorkloadFactory()
    {}

    const BackendId& GetBackendId() const override;

    bool SupportsSubTensors() const override
    {
        return false;
    }

    ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateSubTensorHandle instead")
    std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle&,
                                                         TensorShape const&,
                                                         unsigned int const*) const override
    {
        return nullptr;
    }

    ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
    std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
                                                      const bool IsMemoryManaged = true) const override
    {
        IgnoreUnused(IsMemoryManaged);
        return std::make_unique<MockTensorHandle>(tensorInfo, m_MemoryManager);
    };

    ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
    std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
                                                      DataLayout dataLayout,
                                                      const bool IsMemoryManaged = true) const override
    {
        IgnoreUnused(dataLayout, IsMemoryManaged);
        return std::make_unique<MockTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
    };

    ARMNN_DEPRECATED_MSG_REMOVAL_DATE(
        "Use ABI stable "
        "CreateWorkload(LayerType, const QueueDescriptor&, const WorkloadInfo& info) instead.",
        "23.08")
    std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor,
                                           const WorkloadInfo& info) const override
    {
        if (info.m_InputTensorInfos.empty())
        {
            throw InvalidArgumentException("MockWorkloadFactory::CreateInput: Input cannot be zero length");
        }
        if (info.m_OutputTensorInfos.empty())
        {
            throw InvalidArgumentException("MockWorkloadFactory::CreateInput: Output cannot be zero length");
        }

        if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
        {
            throw InvalidArgumentException(
                "MockWorkloadFactory::CreateInput: data input and output differ in byte count.");
        }

        return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
    };

    std::unique_ptr<IWorkload>
        CreateWorkload(LayerType type, const QueueDescriptor& descriptor, const WorkloadInfo& info) const override;

private:
    mutable std::shared_ptr<MockMemoryManager> m_MemoryManager;
};

class MockBackendInitialiser
{
public:
    MockBackendInitialiser();
    ~MockBackendInitialiser();
};

class MockBackendProfilingContext : public arm::pipe::IBackendProfilingContext
{
public:
    MockBackendProfilingContext(IBackendInternal::IBackendProfilingPtr& backendProfiling)
        : m_BackendProfiling(std::move(backendProfiling))
        , m_CapturePeriod(0)
        , m_IsTimelineEnabled(true)
    {}

    ~MockBackendProfilingContext() = default;

    IBackendInternal::IBackendProfilingPtr& GetBackendProfiling()
    {
        return m_BackendProfiling;
    }

    uint16_t RegisterCounters(uint16_t currentMaxGlobalCounterId)
    {
        std::unique_ptr<arm::pipe::IRegisterBackendCounters> counterRegistrar =
            m_BackendProfiling->GetCounterRegistrationInterface(static_cast<uint16_t>(currentMaxGlobalCounterId));

        std::string categoryName("MockCounters");
        counterRegistrar->RegisterCategory(categoryName);

        counterRegistrar->RegisterCounter(0, categoryName, 0, 0, 1.f, "Mock Counter One", "Some notional counter");

        counterRegistrar->RegisterCounter(1, categoryName, 0, 0, 1.f, "Mock Counter Two",
                                          "Another notional counter");

        std::string units("microseconds");
        uint16_t nextMaxGlobalCounterId =
            counterRegistrar->RegisterCounter(2, categoryName, 0, 0, 1.f, "Mock MultiCore Counter",
                                              "A dummy four core counter", units, 4);
        return nextMaxGlobalCounterId;
    }

    arm::pipe::Optional<std::string> ActivateCounters(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds)
    {
        if (capturePeriod == 0 || counterIds.size() == 0)
        {
            m_ActiveCounters.clear();
        }
        else if (capturePeriod == 15939u)
        {
            return arm::pipe::Optional<std::string>("ActivateCounters example test error");
        }
        m_CapturePeriod = capturePeriod;
        m_ActiveCounters = counterIds;
        return arm::pipe::Optional<std::string>();
    }

    std::vector<arm::pipe::Timestamp> ReportCounterValues()
    {
        std::vector<arm::pipe::CounterValue> counterValues;

        for (auto counterId : m_ActiveCounters)
        {
            counterValues.emplace_back(arm::pipe::CounterValue{ counterId, counterId + 1u });
        }

        uint64_t timestamp = m_CapturePeriod;
        return { arm::pipe::Timestamp{ timestamp, counterValues } };
    }

    bool EnableProfiling(bool)
    {
        auto sendTimelinePacket = m_BackendProfiling->GetSendTimelinePacket();
        sendTimelinePacket->SendTimelineEntityBinaryPacket(4256);
        sendTimelinePacket->Commit();
        return true;
    }

    bool EnableTimelineReporting(bool isEnabled)
    {
        m_IsTimelineEnabled = isEnabled;
        return isEnabled;
    }

    bool TimelineReportingEnabled()
    {
        return m_IsTimelineEnabled;
    }

private:
    IBackendInternal::IBackendProfilingPtr m_BackendProfiling;
    uint32_t m_CapturePeriod;
    std::vector<uint16_t> m_ActiveCounters;
    std::atomic<bool> m_IsTimelineEnabled;
};

class MockBackendProfilingService
{
public:
    // Getter for the singleton instance
    static MockBackendProfilingService& Instance()
    {
        static MockBackendProfilingService instance;
        return instance;
    }

    MockBackendProfilingContext* GetContext()
    {
        return m_sharedContext.get();
    }

    void SetProfilingContextPtr(std::shared_ptr<MockBackendProfilingContext> shared)
    {
        m_sharedContext = shared;
    }

private:
    std::shared_ptr<MockBackendProfilingContext> m_sharedContext;
};
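
// Illustrative sketch (added for documentation; not part of the original header). It shows the
// intended wiring of the singleton above: a test registers a shared profiling context once and
// can later fetch the raw pointer from anywhere. The helper name and parameter are hypothetical.
inline MockBackendProfilingContext* RegisterMockProfilingContextSketch(
    const std::shared_ptr<MockBackendProfilingContext>& context)
{
    // Store the shared context in the singleton service...
    MockBackendProfilingService::Instance().SetProfilingContextPtr(context);

    // ...then any other test helper can retrieve it without taking ownership.
    return MockBackendProfilingService::Instance().GetContext();
}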

class MockLayerSupport : public LayerSupportBase
{
public:
    bool IsLayerSupported(const LayerType& type,
                          const std::vector<TensorInfo>& infos,
                          const BaseDescriptor& descriptor,
                          const Optional<LstmInputParamsInfo>& /*lstmParamsInfo*/,
                          const Optional<QuantizedLstmInputParamsInfo>& /*quantizedLstmParamsInfo*/,
                          Optional<std::string&> reasonIfUnsupported) const override
    {
        switch(type)
        {
            case LayerType::Input:
                return IsInputSupported(infos[0], reasonIfUnsupported);
            case LayerType::Output:
                return IsOutputSupported(infos[0], reasonIfUnsupported);
            case LayerType::Addition:
                return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
            case LayerType::Convolution2d:
            {
                if (infos.size() != 4)
                {
                    throw InvalidArgumentException("Invalid number of Convolution2d "
                                                   "TensorInfos. TensorInfos should be of format: "
                                                   "{input, output, weights, biases}.");
                }

                auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
                if (infos[3] == TensorInfo())
                {
                    return IsConvolution2dSupported(infos[0],
                                                    infos[1],
                                                    desc,
                                                    infos[2],
                                                    EmptyOptional(),
                                                    reasonIfUnsupported);
                }
                else
                {
                    return IsConvolution2dSupported(infos[0],
                                                    infos[1],
                                                    desc,
                                                    infos[2],
                                                    infos[3],
                                                    reasonIfUnsupported);
                }
            }
            default:
                return false;
        }
    }

    bool IsInputSupported(const TensorInfo& /*input*/,
                          Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
    {
        return true;
    }

    bool IsOutputSupported(const TensorInfo& /*input*/,
                           Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
    {
        return true;
    }

    bool IsAdditionSupported(const TensorInfo& /*input0*/,
                             const TensorInfo& /*input1*/,
                             const TensorInfo& /*output*/,
                             Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
    {
        return true;
    }

    bool IsConvolution2dSupported(const TensorInfo& /*input*/,
                                  const TensorInfo& /*output*/,
                                  const Convolution2dDescriptor& /*descriptor*/,
                                  const TensorInfo& /*weights*/,
                                  const Optional<TensorInfo>& /*biases*/,
                                  Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
    {
        return true;
    }
};

} // namespace armnn