blob: 2ea3c2abf180eb69570bbaa7af9071468f67b8ea [file] [log] [blame]
Derek Lamberti84da38b2019-06-13 11:40:08 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
Sadik Armagan1625efc2021-06-10 18:24:34 +01005
#include <doctest/doctest.h>

#include <armnn/LayerVisitorBase.hpp>

#include <armnn/backends/IBackendContext.hpp>
#include <armnn/backends/IBackendInternal.hpp>
#include <armnn/backends/IMemoryManager.hpp>
#include <armnn/backends/ITensorHandleFactory.hpp>
#include <backendsCommon/TensorHandleFactoryRegistry.hpp>

#include <optimizations/Optimization.hpp>

#include <Network.hpp>

#include <armnn/utility/IgnoreUnused.hpp>

#include <memory>
#include <string>
#include <utility>
#include <vector>
24
Francis Murtaghb4f312c2019-12-31 12:44:20 +000025
Derek Lamberti84da38b2019-06-13 11:40:08 +010026using namespace armnn;
27
28class TestMemMgr : public IMemoryManager
29{
30public:
31 TestMemMgr() = default;
32
33 void Acquire() override {}
34 void Release() override {}
35};
36
37class TestFactory1 : public ITensorHandleFactory
38{
39public:
40 TestFactory1(std::weak_ptr<IMemoryManager> mgr, ITensorHandleFactory::FactoryId id)
41 : m_Id(id)
42 , m_MemMgr(mgr)
43 {}
44
45 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
46 TensorShape const& subTensorShape,
47 unsigned int const* subTensorOrigin) const override
48 {
Jan Eilers8eb25602020-03-09 12:13:48 +000049 IgnoreUnused(parent, subTensorShape, subTensorOrigin);
Derek Lamberti84da38b2019-06-13 11:40:08 +010050 return nullptr;
51 }
52
David Monahanc6e5a6e2019-10-02 09:33:57 +010053 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override
Derek Lamberti84da38b2019-06-13 11:40:08 +010054 {
Jan Eilers8eb25602020-03-09 12:13:48 +000055 IgnoreUnused(tensorInfo);
Derek Lamberti84da38b2019-06-13 11:40:08 +010056 return nullptr;
57 }
58
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010059 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
David Monahanc6e5a6e2019-10-02 09:33:57 +010060 DataLayout dataLayout) const override
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010061 {
Jan Eilers8eb25602020-03-09 12:13:48 +000062 IgnoreUnused(tensorInfo, dataLayout);
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010063 return nullptr;
64 }
65
Ferran Balaguerbfeb2712019-08-07 15:14:56 +010066 const FactoryId& GetId() const override { return m_Id; }
Derek Lamberti84da38b2019-06-13 11:40:08 +010067
Derek Lambertif674aa02019-08-01 15:56:25 +010068 bool SupportsSubTensors() const override { return true; }
69
70 MemorySourceFlags GetExportFlags() const override { return 1; }
Derek Lamberti84da38b2019-06-13 11:40:08 +010071
72private:
73 FactoryId m_Id = "UninitializedId";
74
75 std::weak_ptr<IMemoryManager> m_MemMgr;
76};
77
Derek Lambertif674aa02019-08-01 15:56:25 +010078class TestFactoryImport : public ITensorHandleFactory
79{
80public:
81 TestFactoryImport(std::weak_ptr<IMemoryManager> mgr, ITensorHandleFactory::FactoryId id)
82 : m_Id(id)
83 , m_MemMgr(mgr)
84 {}
85
86 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
87 TensorShape const& subTensorShape,
88 unsigned int const* subTensorOrigin) const override
89 {
Jan Eilers8eb25602020-03-09 12:13:48 +000090 IgnoreUnused(parent, subTensorShape, subTensorOrigin);
Derek Lambertif674aa02019-08-01 15:56:25 +010091 return nullptr;
92 }
93
David Monahanc6e5a6e2019-10-02 09:33:57 +010094 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override
Derek Lambertif674aa02019-08-01 15:56:25 +010095 {
Jan Eilers8eb25602020-03-09 12:13:48 +000096 IgnoreUnused(tensorInfo);
Derek Lambertif674aa02019-08-01 15:56:25 +010097 return nullptr;
98 }
99
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +0100100 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
David Monahanc6e5a6e2019-10-02 09:33:57 +0100101 DataLayout dataLayout) const override
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +0100102 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000103 IgnoreUnused(tensorInfo, dataLayout);
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +0100104 return nullptr;
105 }
106
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100107 const FactoryId& GetId() const override { return m_Id; }
Derek Lambertif674aa02019-08-01 15:56:25 +0100108
109 bool SupportsSubTensors() const override { return true; }
110
111 MemorySourceFlags GetImportFlags() const override { return 1; }
112
113private:
114 FactoryId m_Id = "ImporterId";
115
116 std::weak_ptr<IMemoryManager> m_MemMgr;
117};
118
Derek Lamberti84da38b2019-06-13 11:40:08 +0100119class TestBackendA : public IBackendInternal
120{
121public:
122 TestBackendA() = default;
123
124 const BackendId& GetId() const override { return m_Id; }
125
126 IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override
127 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000128 IgnoreUnused(memoryManager);
Derek Lamberti84da38b2019-06-13 11:40:08 +0100129 return IWorkloadFactoryPtr{};
130 }
131
132 IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
133 {
134 return ILayerSupportSharedPtr{};
135 }
136
137 std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override
138 {
139 return std::vector<ITensorHandleFactory::FactoryId>
140 {
141 "TestHandleFactoryA1",
142 "TestHandleFactoryA2",
Narumol Prangnawarate5f0b242021-05-07 17:52:36 +0100143 "TestHandleFactoryB1",
144 "TestHandleFactoryD1"
Derek Lamberti84da38b2019-06-13 11:40:08 +0100145 };
146 }
147
148 void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override
149 {
150 auto mgr = std::make_shared<TestMemMgr>();
151
152 registry.RegisterMemoryManager(mgr);
153 registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryA1"));
154 registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryA2"));
155 }
156
157private:
158 BackendId m_Id = "BackendA";
159};
160
161class TestBackendB : public IBackendInternal
162{
163public:
164 TestBackendB() = default;
165
166 const BackendId& GetId() const override { return m_Id; }
167
168 IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override
169 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000170 IgnoreUnused(memoryManager);
Derek Lamberti84da38b2019-06-13 11:40:08 +0100171 return IWorkloadFactoryPtr{};
172 }
173
174 IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
175 {
176 return ILayerSupportSharedPtr{};
177 }
178
179 std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override
180 {
181 return std::vector<ITensorHandleFactory::FactoryId>
182 {
183 "TestHandleFactoryB1"
184 };
185 }
186
187 void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override
188 {
189 auto mgr = std::make_shared<TestMemMgr>();
190
191 registry.RegisterMemoryManager(mgr);
192 registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryB1"));
193 }
194
195private:
196 BackendId m_Id = "BackendB";
197};
198
199class TestBackendC : public IBackendInternal
200{
201public:
202 TestBackendC() = default;
203
204 const BackendId& GetId() const override { return m_Id; }
205
206 IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override
207 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000208 IgnoreUnused(memoryManager);
Derek Lamberti84da38b2019-06-13 11:40:08 +0100209 return IWorkloadFactoryPtr{};
210 }
211
212 IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
213 {
214 return ILayerSupportSharedPtr{};
215 }
216
217 std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override
218 {
219 return std::vector<ITensorHandleFactory::FactoryId>{
220 "TestHandleFactoryC1"
221 };
222 }
223
224 void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override
225 {
226 auto mgr = std::make_shared<TestMemMgr>();
227
228 registry.RegisterMemoryManager(mgr);
229 registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryC1"));
230 }
231
232private:
233 BackendId m_Id = "BackendC";
234};
235
Derek Lambertif674aa02019-08-01 15:56:25 +0100236class TestBackendD : public IBackendInternal
237{
238public:
239 TestBackendD() = default;
240
241 const BackendId& GetId() const override { return m_Id; }
242
243 IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override
244 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000245 IgnoreUnused(memoryManager);
Derek Lambertif674aa02019-08-01 15:56:25 +0100246 return IWorkloadFactoryPtr{};
247 }
248
249 IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
250 {
251 return ILayerSupportSharedPtr{};
252 }
253
254 std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override
255 {
256 return std::vector<ITensorHandleFactory::FactoryId>{
Narumol Prangnawarate5f0b242021-05-07 17:52:36 +0100257 "TestHandleFactoryD1",
Derek Lambertif674aa02019-08-01 15:56:25 +0100258 };
259 }
260
261 void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override
262 {
263 auto mgr = std::make_shared<TestMemMgr>();
264
265 registry.RegisterMemoryManager(mgr);
266 registry.RegisterFactory(std::make_unique<TestFactoryImport>(mgr, "TestHandleFactoryD1"));
267 }
268
269private:
270 BackendId m_Id = "BackendD";
271};
272
Derek Lamberti84da38b2019-06-13 11:40:08 +0100273
Sadik Armagan1625efc2021-06-10 18:24:34 +0100274TEST_SUITE("TensorHandle")
275{
276TEST_CASE("RegisterFactories")
Derek Lamberti84da38b2019-06-13 11:40:08 +0100277{
278 TestBackendA backendA;
279 TestBackendB backendB;
280
Sadik Armagan1625efc2021-06-10 18:24:34 +0100281 CHECK(backendA.GetHandleFactoryPreferences()[0] == "TestHandleFactoryA1");
282 CHECK(backendA.GetHandleFactoryPreferences()[1] == "TestHandleFactoryA2");
283 CHECK(backendA.GetHandleFactoryPreferences()[2] == "TestHandleFactoryB1");
284 CHECK(backendA.GetHandleFactoryPreferences()[3] == "TestHandleFactoryD1");
Derek Lamberti84da38b2019-06-13 11:40:08 +0100285
286 TensorHandleFactoryRegistry registry;
287 backendA.RegisterTensorHandleFactories(registry);
288 backendB.RegisterTensorHandleFactories(registry);
289
Sadik Armagan1625efc2021-06-10 18:24:34 +0100290 CHECK((registry.GetFactory("Non-existing Backend") == nullptr));
291 CHECK((registry.GetFactory("TestHandleFactoryA1") != nullptr));
292 CHECK((registry.GetFactory("TestHandleFactoryA2") != nullptr));
293 CHECK((registry.GetFactory("TestHandleFactoryB1") != nullptr));
Derek Lamberti84da38b2019-06-13 11:40:08 +0100294}
295
Sadik Armagan1625efc2021-06-10 18:24:34 +0100296TEST_CASE("TensorHandleSelectionStrategy")
Derek Lamberti84da38b2019-06-13 11:40:08 +0100297{
298 auto backendA = std::make_unique<TestBackendA>();
299 auto backendB = std::make_unique<TestBackendB>();
300 auto backendC = std::make_unique<TestBackendC>();
Derek Lambertif674aa02019-08-01 15:56:25 +0100301 auto backendD = std::make_unique<TestBackendD>();
Derek Lamberti84da38b2019-06-13 11:40:08 +0100302
303 TensorHandleFactoryRegistry registry;
304 backendA->RegisterTensorHandleFactories(registry);
305 backendB->RegisterTensorHandleFactories(registry);
306 backendC->RegisterTensorHandleFactories(registry);
Derek Lambertif674aa02019-08-01 15:56:25 +0100307 backendD->RegisterTensorHandleFactories(registry);
Derek Lamberti84da38b2019-06-13 11:40:08 +0100308
309 BackendsMap backends;
310 backends["BackendA"] = std::move(backendA);
311 backends["BackendB"] = std::move(backendB);
312 backends["BackendC"] = std::move(backendC);
Derek Lambertif674aa02019-08-01 15:56:25 +0100313 backends["BackendD"] = std::move(backendD);
Derek Lamberti84da38b2019-06-13 11:40:08 +0100314
315 armnn::Graph graph;
316
317 armnn::InputLayer* const inputLayer = graph.AddLayer<armnn::InputLayer>(0, "input");
318 inputLayer->SetBackendId("BackendA");
319
320 armnn::SoftmaxDescriptor smDesc;
321 armnn::SoftmaxLayer* const softmaxLayer1 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax1");
322 softmaxLayer1->SetBackendId("BackendA");
323
324 armnn::SoftmaxLayer* const softmaxLayer2 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax2");
325 softmaxLayer2->SetBackendId("BackendB");
326
327 armnn::SoftmaxLayer* const softmaxLayer3 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax3");
328 softmaxLayer3->SetBackendId("BackendC");
329
Derek Lambertif674aa02019-08-01 15:56:25 +0100330 armnn::SoftmaxLayer* const softmaxLayer4 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax4");
331 softmaxLayer4->SetBackendId("BackendD");
332
Derek Lamberti84da38b2019-06-13 11:40:08 +0100333 armnn::OutputLayer* const outputLayer = graph.AddLayer<armnn::OutputLayer>(0, "output");
334 outputLayer->SetBackendId("BackendA");
335
336 inputLayer->GetOutputSlot(0).Connect(softmaxLayer1->GetInputSlot(0));
337 softmaxLayer1->GetOutputSlot(0).Connect(softmaxLayer2->GetInputSlot(0));
338 softmaxLayer2->GetOutputSlot(0).Connect(softmaxLayer3->GetInputSlot(0));
Derek Lambertif674aa02019-08-01 15:56:25 +0100339 softmaxLayer3->GetOutputSlot(0).Connect(softmaxLayer4->GetInputSlot(0));
340 softmaxLayer4->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
Derek Lamberti84da38b2019-06-13 11:40:08 +0100341
342 graph.TopologicalSort();
343
344 std::vector<std::string> errors;
Colm Donelan03bf98a2022-05-30 15:20:36 +0100345 auto result = SelectTensorHandleStrategy(graph, backends, registry, true, true, errors);
Derek Lamberti84da38b2019-06-13 11:40:08 +0100346
Sadik Armagan1625efc2021-06-10 18:24:34 +0100347 CHECK(result.m_Error == false);
348 CHECK(result.m_Warning == false);
Derek Lamberti84da38b2019-06-13 11:40:08 +0100349
350 OutputSlot& inputLayerOut = inputLayer->GetOutputSlot(0);
351 OutputSlot& softmaxLayer1Out = softmaxLayer1->GetOutputSlot(0);
352 OutputSlot& softmaxLayer2Out = softmaxLayer2->GetOutputSlot(0);
353 OutputSlot& softmaxLayer3Out = softmaxLayer3->GetOutputSlot(0);
Derek Lambertif674aa02019-08-01 15:56:25 +0100354 OutputSlot& softmaxLayer4Out = softmaxLayer4->GetOutputSlot(0);
Derek Lamberti84da38b2019-06-13 11:40:08 +0100355
356 // Check that the correct factory was selected
Sadik Armagan1625efc2021-06-10 18:24:34 +0100357 CHECK(inputLayerOut.GetTensorHandleFactoryId() == "TestHandleFactoryD1");
358 CHECK(softmaxLayer1Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1");
359 CHECK(softmaxLayer2Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1");
360 CHECK(softmaxLayer3Out.GetTensorHandleFactoryId() == "TestHandleFactoryC1");
361 CHECK(softmaxLayer4Out.GetTensorHandleFactoryId() == "TestHandleFactoryD1");
Derek Lamberti84da38b2019-06-13 11:40:08 +0100362
363 // Check that the correct strategy was selected
Sadik Armagan1625efc2021-06-10 18:24:34 +0100364 CHECK((inputLayerOut.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
365 CHECK((softmaxLayer1Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
366 CHECK((softmaxLayer2Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::CopyToTarget));
367 CHECK((softmaxLayer3Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::ExportToTarget));
368 CHECK((softmaxLayer4Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
Derek Lamberti84da38b2019-06-13 11:40:08 +0100369
Derek Lambertif674aa02019-08-01 15:56:25 +0100370 graph.AddCompatibilityLayers(backends, registry);
371
372 // Test for copy layers
373 int copyCount= 0;
374 graph.ForEachLayer([&copyCount](Layer* layer)
Derek Lamberti84da38b2019-06-13 11:40:08 +0100375 {
376 if (layer->GetType() == LayerType::MemCopy)
377 {
Derek Lambertif674aa02019-08-01 15:56:25 +0100378 copyCount++;
Derek Lamberti84da38b2019-06-13 11:40:08 +0100379 }
380 });
Sadik Armagan1625efc2021-06-10 18:24:34 +0100381 CHECK(copyCount == 1);
Derek Lambertif674aa02019-08-01 15:56:25 +0100382
383 // Test for import layers
384 int importCount= 0;
385 graph.ForEachLayer([&importCount](Layer *layer)
386 {
387 if (layer->GetType() == LayerType::MemImport)
388 {
389 importCount++;
390 }
391 });
Sadik Armagan1625efc2021-06-10 18:24:34 +0100392 CHECK(importCount == 1);
Derek Lamberti84da38b2019-06-13 11:40:08 +0100393}
394
Narumol Prangnawaratb275da52021-12-17 17:27:37 +0000395TEST_CASE("RegisterCopyAndImportFactoryPairTest")
396{
397 TensorHandleFactoryRegistry registry;
398 ITensorHandleFactory::FactoryId copyId = "CopyFactoryId";
399 ITensorHandleFactory::FactoryId importId = "ImportFactoryId";
400 registry.RegisterCopyAndImportFactoryPair(copyId, importId);
401
402 // Get mathing import factory id correctly
403 CHECK((registry.GetMatchingImportFactoryId(copyId) == importId));
404
Narumol Prangnawarat1c52a382022-01-13 11:47:35 +0000405 // Return empty id when Invalid Id is given
Narumol Prangnawaratb275da52021-12-17 17:27:37 +0000406 CHECK((registry.GetMatchingImportFactoryId("InvalidFactoryId") == ""));
407}
408
Sadik Armagan1625efc2021-06-10 18:24:34 +0100409}