Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 5 | |
| 6 | #include <doctest/doctest.h> |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 7 | |
| 8 | #include <armnn/LayerVisitorBase.hpp> |
| 9 | |
Matteo Martincigh | e5b8eb9 | 2019-11-28 15:45:42 +0000 | [diff] [blame] | 10 | #include <armnn/backends/IBackendContext.hpp> |
| 11 | #include <armnn/backends/IBackendInternal.hpp> |
| 12 | #include <armnn/backends/IMemoryManager.hpp> |
| 13 | #include <armnn/backends/ITensorHandleFactory.hpp> |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 14 | #include <backendsCommon/TensorHandleFactoryRegistry.hpp> |
| 15 | |
| 16 | #include <optimizations/Optimization.hpp> |
| 17 | |
| 18 | #include <Network.hpp> |
| 19 | |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 20 | #include <armnn/utility/IgnoreUnused.hpp> |
| 21 | |
#include <memory>
#include <string>
#include <utility>
#include <vector>
| 24 | |
Francis Murtagh | b4f312c | 2019-12-31 12:44:20 +0000 | [diff] [blame] | 25 | |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 26 | using namespace armnn; |
| 27 | |
| 28 | class TestMemMgr : public IMemoryManager |
| 29 | { |
| 30 | public: |
| 31 | TestMemMgr() = default; |
| 32 | |
| 33 | void Acquire() override {} |
| 34 | void Release() override {} |
| 35 | }; |
| 36 | |
| 37 | class TestFactory1 : public ITensorHandleFactory |
| 38 | { |
| 39 | public: |
| 40 | TestFactory1(std::weak_ptr<IMemoryManager> mgr, ITensorHandleFactory::FactoryId id) |
| 41 | : m_Id(id) |
| 42 | , m_MemMgr(mgr) |
| 43 | {} |
| 44 | |
| 45 | std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent, |
| 46 | TensorShape const& subTensorShape, |
| 47 | unsigned int const* subTensorOrigin) const override |
| 48 | { |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 49 | IgnoreUnused(parent, subTensorShape, subTensorOrigin); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 50 | return nullptr; |
| 51 | } |
| 52 | |
David Monahan | c6e5a6e | 2019-10-02 09:33:57 +0100 | [diff] [blame] | 53 | std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 54 | { |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 55 | IgnoreUnused(tensorInfo); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 56 | return nullptr; |
| 57 | } |
| 58 | |
Narumol Prangnawarat | 4e3e818 | 2019-08-14 12:25:50 +0100 | [diff] [blame] | 59 | std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, |
David Monahan | c6e5a6e | 2019-10-02 09:33:57 +0100 | [diff] [blame] | 60 | DataLayout dataLayout) const override |
Narumol Prangnawarat | 4e3e818 | 2019-08-14 12:25:50 +0100 | [diff] [blame] | 61 | { |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 62 | IgnoreUnused(tensorInfo, dataLayout); |
Narumol Prangnawarat | 4e3e818 | 2019-08-14 12:25:50 +0100 | [diff] [blame] | 63 | return nullptr; |
| 64 | } |
| 65 | |
Ferran Balaguer | bfeb271 | 2019-08-07 15:14:56 +0100 | [diff] [blame] | 66 | const FactoryId& GetId() const override { return m_Id; } |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 67 | |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 68 | bool SupportsSubTensors() const override { return true; } |
| 69 | |
| 70 | MemorySourceFlags GetExportFlags() const override { return 1; } |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 71 | |
| 72 | private: |
| 73 | FactoryId m_Id = "UninitializedId"; |
| 74 | |
| 75 | std::weak_ptr<IMemoryManager> m_MemMgr; |
| 76 | }; |
| 77 | |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 78 | class TestFactoryImport : public ITensorHandleFactory |
| 79 | { |
| 80 | public: |
| 81 | TestFactoryImport(std::weak_ptr<IMemoryManager> mgr, ITensorHandleFactory::FactoryId id) |
| 82 | : m_Id(id) |
| 83 | , m_MemMgr(mgr) |
| 84 | {} |
| 85 | |
| 86 | std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent, |
| 87 | TensorShape const& subTensorShape, |
| 88 | unsigned int const* subTensorOrigin) const override |
| 89 | { |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 90 | IgnoreUnused(parent, subTensorShape, subTensorOrigin); |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 91 | return nullptr; |
| 92 | } |
| 93 | |
David Monahan | c6e5a6e | 2019-10-02 09:33:57 +0100 | [diff] [blame] | 94 | std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 95 | { |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 96 | IgnoreUnused(tensorInfo); |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 97 | return nullptr; |
| 98 | } |
| 99 | |
Narumol Prangnawarat | 4e3e818 | 2019-08-14 12:25:50 +0100 | [diff] [blame] | 100 | std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, |
David Monahan | c6e5a6e | 2019-10-02 09:33:57 +0100 | [diff] [blame] | 101 | DataLayout dataLayout) const override |
Narumol Prangnawarat | 4e3e818 | 2019-08-14 12:25:50 +0100 | [diff] [blame] | 102 | { |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 103 | IgnoreUnused(tensorInfo, dataLayout); |
Narumol Prangnawarat | 4e3e818 | 2019-08-14 12:25:50 +0100 | [diff] [blame] | 104 | return nullptr; |
| 105 | } |
| 106 | |
Ferran Balaguer | bfeb271 | 2019-08-07 15:14:56 +0100 | [diff] [blame] | 107 | const FactoryId& GetId() const override { return m_Id; } |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 108 | |
| 109 | bool SupportsSubTensors() const override { return true; } |
| 110 | |
| 111 | MemorySourceFlags GetImportFlags() const override { return 1; } |
| 112 | |
| 113 | private: |
| 114 | FactoryId m_Id = "ImporterId"; |
| 115 | |
| 116 | std::weak_ptr<IMemoryManager> m_MemMgr; |
| 117 | }; |
| 118 | |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 119 | class TestBackendA : public IBackendInternal |
| 120 | { |
| 121 | public: |
| 122 | TestBackendA() = default; |
| 123 | |
| 124 | const BackendId& GetId() const override { return m_Id; } |
| 125 | |
| 126 | IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override |
| 127 | { |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 128 | IgnoreUnused(memoryManager); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 129 | return IWorkloadFactoryPtr{}; |
| 130 | } |
| 131 | |
| 132 | IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override |
| 133 | { |
| 134 | return ILayerSupportSharedPtr{}; |
| 135 | } |
| 136 | |
| 137 | std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override |
| 138 | { |
| 139 | return std::vector<ITensorHandleFactory::FactoryId> |
| 140 | { |
| 141 | "TestHandleFactoryA1", |
| 142 | "TestHandleFactoryA2", |
Narumol Prangnawarat | e5f0b24 | 2021-05-07 17:52:36 +0100 | [diff] [blame] | 143 | "TestHandleFactoryB1", |
| 144 | "TestHandleFactoryD1" |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 145 | }; |
| 146 | } |
| 147 | |
| 148 | void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override |
| 149 | { |
| 150 | auto mgr = std::make_shared<TestMemMgr>(); |
| 151 | |
| 152 | registry.RegisterMemoryManager(mgr); |
| 153 | registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryA1")); |
| 154 | registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryA2")); |
| 155 | } |
| 156 | |
| 157 | private: |
| 158 | BackendId m_Id = "BackendA"; |
| 159 | }; |
| 160 | |
| 161 | class TestBackendB : public IBackendInternal |
| 162 | { |
| 163 | public: |
| 164 | TestBackendB() = default; |
| 165 | |
| 166 | const BackendId& GetId() const override { return m_Id; } |
| 167 | |
| 168 | IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override |
| 169 | { |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 170 | IgnoreUnused(memoryManager); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 171 | return IWorkloadFactoryPtr{}; |
| 172 | } |
| 173 | |
| 174 | IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override |
| 175 | { |
| 176 | return ILayerSupportSharedPtr{}; |
| 177 | } |
| 178 | |
| 179 | std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override |
| 180 | { |
| 181 | return std::vector<ITensorHandleFactory::FactoryId> |
| 182 | { |
| 183 | "TestHandleFactoryB1" |
| 184 | }; |
| 185 | } |
| 186 | |
| 187 | void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override |
| 188 | { |
| 189 | auto mgr = std::make_shared<TestMemMgr>(); |
| 190 | |
| 191 | registry.RegisterMemoryManager(mgr); |
| 192 | registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryB1")); |
| 193 | } |
| 194 | |
| 195 | private: |
| 196 | BackendId m_Id = "BackendB"; |
| 197 | }; |
| 198 | |
| 199 | class TestBackendC : public IBackendInternal |
| 200 | { |
| 201 | public: |
| 202 | TestBackendC() = default; |
| 203 | |
| 204 | const BackendId& GetId() const override { return m_Id; } |
| 205 | |
| 206 | IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override |
| 207 | { |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 208 | IgnoreUnused(memoryManager); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 209 | return IWorkloadFactoryPtr{}; |
| 210 | } |
| 211 | |
| 212 | IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override |
| 213 | { |
| 214 | return ILayerSupportSharedPtr{}; |
| 215 | } |
| 216 | |
| 217 | std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override |
| 218 | { |
| 219 | return std::vector<ITensorHandleFactory::FactoryId>{ |
| 220 | "TestHandleFactoryC1" |
| 221 | }; |
| 222 | } |
| 223 | |
| 224 | void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override |
| 225 | { |
| 226 | auto mgr = std::make_shared<TestMemMgr>(); |
| 227 | |
| 228 | registry.RegisterMemoryManager(mgr); |
| 229 | registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryC1")); |
| 230 | } |
| 231 | |
| 232 | private: |
| 233 | BackendId m_Id = "BackendC"; |
| 234 | }; |
| 235 | |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 236 | class TestBackendD : public IBackendInternal |
| 237 | { |
| 238 | public: |
| 239 | TestBackendD() = default; |
| 240 | |
| 241 | const BackendId& GetId() const override { return m_Id; } |
| 242 | |
| 243 | IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override |
| 244 | { |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 245 | IgnoreUnused(memoryManager); |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 246 | return IWorkloadFactoryPtr{}; |
| 247 | } |
| 248 | |
| 249 | IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override |
| 250 | { |
| 251 | return ILayerSupportSharedPtr{}; |
| 252 | } |
| 253 | |
| 254 | std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override |
| 255 | { |
| 256 | return std::vector<ITensorHandleFactory::FactoryId>{ |
Narumol Prangnawarat | e5f0b24 | 2021-05-07 17:52:36 +0100 | [diff] [blame] | 257 | "TestHandleFactoryD1", |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 258 | }; |
| 259 | } |
| 260 | |
| 261 | void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override |
| 262 | { |
| 263 | auto mgr = std::make_shared<TestMemMgr>(); |
| 264 | |
| 265 | registry.RegisterMemoryManager(mgr); |
| 266 | registry.RegisterFactory(std::make_unique<TestFactoryImport>(mgr, "TestHandleFactoryD1")); |
| 267 | } |
| 268 | |
| 269 | private: |
| 270 | BackendId m_Id = "BackendD"; |
| 271 | }; |
| 272 | |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 273 | |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 274 | TEST_SUITE("TensorHandle") |
| 275 | { |
| 276 | TEST_CASE("RegisterFactories") |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 277 | { |
| 278 | TestBackendA backendA; |
| 279 | TestBackendB backendB; |
| 280 | |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 281 | CHECK(backendA.GetHandleFactoryPreferences()[0] == "TestHandleFactoryA1"); |
| 282 | CHECK(backendA.GetHandleFactoryPreferences()[1] == "TestHandleFactoryA2"); |
| 283 | CHECK(backendA.GetHandleFactoryPreferences()[2] == "TestHandleFactoryB1"); |
| 284 | CHECK(backendA.GetHandleFactoryPreferences()[3] == "TestHandleFactoryD1"); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 285 | |
| 286 | TensorHandleFactoryRegistry registry; |
| 287 | backendA.RegisterTensorHandleFactories(registry); |
| 288 | backendB.RegisterTensorHandleFactories(registry); |
| 289 | |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 290 | CHECK((registry.GetFactory("Non-existing Backend") == nullptr)); |
| 291 | CHECK((registry.GetFactory("TestHandleFactoryA1") != nullptr)); |
| 292 | CHECK((registry.GetFactory("TestHandleFactoryA2") != nullptr)); |
| 293 | CHECK((registry.GetFactory("TestHandleFactoryB1") != nullptr)); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 294 | } |
| 295 | |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 296 | TEST_CASE("TensorHandleSelectionStrategy") |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 297 | { |
| 298 | auto backendA = std::make_unique<TestBackendA>(); |
| 299 | auto backendB = std::make_unique<TestBackendB>(); |
| 300 | auto backendC = std::make_unique<TestBackendC>(); |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 301 | auto backendD = std::make_unique<TestBackendD>(); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 302 | |
| 303 | TensorHandleFactoryRegistry registry; |
| 304 | backendA->RegisterTensorHandleFactories(registry); |
| 305 | backendB->RegisterTensorHandleFactories(registry); |
| 306 | backendC->RegisterTensorHandleFactories(registry); |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 307 | backendD->RegisterTensorHandleFactories(registry); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 308 | |
| 309 | BackendsMap backends; |
| 310 | backends["BackendA"] = std::move(backendA); |
| 311 | backends["BackendB"] = std::move(backendB); |
| 312 | backends["BackendC"] = std::move(backendC); |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 313 | backends["BackendD"] = std::move(backendD); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 314 | |
| 315 | armnn::Graph graph; |
| 316 | |
| 317 | armnn::InputLayer* const inputLayer = graph.AddLayer<armnn::InputLayer>(0, "input"); |
| 318 | inputLayer->SetBackendId("BackendA"); |
| 319 | |
| 320 | armnn::SoftmaxDescriptor smDesc; |
| 321 | armnn::SoftmaxLayer* const softmaxLayer1 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax1"); |
| 322 | softmaxLayer1->SetBackendId("BackendA"); |
| 323 | |
| 324 | armnn::SoftmaxLayer* const softmaxLayer2 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax2"); |
| 325 | softmaxLayer2->SetBackendId("BackendB"); |
| 326 | |
| 327 | armnn::SoftmaxLayer* const softmaxLayer3 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax3"); |
| 328 | softmaxLayer3->SetBackendId("BackendC"); |
| 329 | |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 330 | armnn::SoftmaxLayer* const softmaxLayer4 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax4"); |
| 331 | softmaxLayer4->SetBackendId("BackendD"); |
| 332 | |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 333 | armnn::OutputLayer* const outputLayer = graph.AddLayer<armnn::OutputLayer>(0, "output"); |
| 334 | outputLayer->SetBackendId("BackendA"); |
| 335 | |
| 336 | inputLayer->GetOutputSlot(0).Connect(softmaxLayer1->GetInputSlot(0)); |
| 337 | softmaxLayer1->GetOutputSlot(0).Connect(softmaxLayer2->GetInputSlot(0)); |
| 338 | softmaxLayer2->GetOutputSlot(0).Connect(softmaxLayer3->GetInputSlot(0)); |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 339 | softmaxLayer3->GetOutputSlot(0).Connect(softmaxLayer4->GetInputSlot(0)); |
| 340 | softmaxLayer4->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 341 | |
| 342 | graph.TopologicalSort(); |
| 343 | |
| 344 | std::vector<std::string> errors; |
Colm Donelan | 03bf98a | 2022-05-30 15:20:36 +0100 | [diff] [blame^] | 345 | auto result = SelectTensorHandleStrategy(graph, backends, registry, true, true, errors); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 346 | |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 347 | CHECK(result.m_Error == false); |
| 348 | CHECK(result.m_Warning == false); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 349 | |
| 350 | OutputSlot& inputLayerOut = inputLayer->GetOutputSlot(0); |
| 351 | OutputSlot& softmaxLayer1Out = softmaxLayer1->GetOutputSlot(0); |
| 352 | OutputSlot& softmaxLayer2Out = softmaxLayer2->GetOutputSlot(0); |
| 353 | OutputSlot& softmaxLayer3Out = softmaxLayer3->GetOutputSlot(0); |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 354 | OutputSlot& softmaxLayer4Out = softmaxLayer4->GetOutputSlot(0); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 355 | |
| 356 | // Check that the correct factory was selected |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 357 | CHECK(inputLayerOut.GetTensorHandleFactoryId() == "TestHandleFactoryD1"); |
| 358 | CHECK(softmaxLayer1Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1"); |
| 359 | CHECK(softmaxLayer2Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1"); |
| 360 | CHECK(softmaxLayer3Out.GetTensorHandleFactoryId() == "TestHandleFactoryC1"); |
| 361 | CHECK(softmaxLayer4Out.GetTensorHandleFactoryId() == "TestHandleFactoryD1"); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 362 | |
| 363 | // Check that the correct strategy was selected |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 364 | CHECK((inputLayerOut.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility)); |
| 365 | CHECK((softmaxLayer1Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility)); |
| 366 | CHECK((softmaxLayer2Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::CopyToTarget)); |
| 367 | CHECK((softmaxLayer3Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::ExportToTarget)); |
| 368 | CHECK((softmaxLayer4Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility)); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 369 | |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 370 | graph.AddCompatibilityLayers(backends, registry); |
| 371 | |
| 372 | // Test for copy layers |
| 373 | int copyCount= 0; |
| 374 | graph.ForEachLayer([©Count](Layer* layer) |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 375 | { |
| 376 | if (layer->GetType() == LayerType::MemCopy) |
| 377 | { |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 378 | copyCount++; |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 379 | } |
| 380 | }); |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 381 | CHECK(copyCount == 1); |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 382 | |
| 383 | // Test for import layers |
| 384 | int importCount= 0; |
| 385 | graph.ForEachLayer([&importCount](Layer *layer) |
| 386 | { |
| 387 | if (layer->GetType() == LayerType::MemImport) |
| 388 | { |
| 389 | importCount++; |
| 390 | } |
| 391 | }); |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 392 | CHECK(importCount == 1); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 393 | } |
| 394 | |
Narumol Prangnawarat | b275da5 | 2021-12-17 17:27:37 +0000 | [diff] [blame] | 395 | TEST_CASE("RegisterCopyAndImportFactoryPairTest") |
| 396 | { |
| 397 | TensorHandleFactoryRegistry registry; |
| 398 | ITensorHandleFactory::FactoryId copyId = "CopyFactoryId"; |
| 399 | ITensorHandleFactory::FactoryId importId = "ImportFactoryId"; |
| 400 | registry.RegisterCopyAndImportFactoryPair(copyId, importId); |
| 401 | |
| 402 | // Get mathing import factory id correctly |
| 403 | CHECK((registry.GetMatchingImportFactoryId(copyId) == importId)); |
| 404 | |
Narumol Prangnawarat | 1c52a38 | 2022-01-13 11:47:35 +0000 | [diff] [blame] | 405 | // Return empty id when Invalid Id is given |
Narumol Prangnawarat | b275da5 | 2021-12-17 17:27:37 +0000 | [diff] [blame] | 406 | CHECK((registry.GetMatchingImportFactoryId("InvalidFactoryId") == "")); |
| 407 | } |
| 408 | |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 409 | } |