blob: db653391e2b934f4078786fdfeea957ae5dfd72f [file] [log] [blame]
Francis Murtagh9270d9e2022-08-12 13:54:17 +01001//
Teresa Charlin571a4f72024-03-26 11:18:42 +00002// Copyright © 2022-2024 Arm Ltd and Contributors. All rights reserved.
Francis Murtagh9270d9e2022-08-12 13:54:17 +01003// SPDX-License-Identifier: MIT
4//
5
6#include "TosaRefBackend.hpp"
7#include "TosaRefBackendId.hpp"
8#include "TosaRefWorkloadFactory.hpp"
9#include "TosaRefLayerSupport.hpp"
10#include "TosaRefTensorHandleFactory.hpp"
11
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010012#include <tosaCommon/TosaMappings.hpp>
Francis Murtagh9270d9e2022-08-12 13:54:17 +010013#include <armnn/BackendRegistry.hpp>
14#include <armnn/backends/IBackendContext.hpp>
15#include <armnn/backends/IMemoryManager.hpp>
16#include <armnn/utility/PolymorphicDowncast.hpp>
17#include <backendsCommon/DefaultAllocator.hpp>
18#include <backendsCommon/SubgraphUtils.hpp>
19
20#include <Optimizer.hpp>
21
22namespace armnn
23{
24
// Type-erased deleter: restores the concrete type T of an opaque blob pointer before deleting it,
// so a TosaSerializationHandler handed to ArmNN as `const void*` is destroyed through its real type.
template <typename T>
void DeleteAsType(const void* const blob)
{
    const T* const typedBlob = static_cast<const T*>(blob);
    delete typedBlob;
}
31
Francis Murtagh9270d9e2022-08-12 13:54:17 +010032const BackendId& TosaRefBackend::GetIdStatic()
33{
34 static const BackendId s_Id{TosaRefBackendId()};
35 return s_Id;
36}
37
38IBackendInternal::IWorkloadFactoryPtr TosaRefBackend::CreateWorkloadFactory(
39 const IBackendInternal::IMemoryManagerSharedPtr& memoryManager) const
40{
41 return std::make_unique<TosaRefWorkloadFactory>(PolymorphicPointerDowncast<TosaRefMemoryManager>(memoryManager));
42}
43
44IBackendInternal::IWorkloadFactoryPtr TosaRefBackend::CreateWorkloadFactory(
45 class TensorHandleFactoryRegistry& tensorHandleFactoryRegistry) const
46{
47 auto memoryManager = std::make_shared<TosaRefMemoryManager>();
48
49 tensorHandleFactoryRegistry.RegisterMemoryManager(memoryManager);
50
51 auto factory = std::make_unique<TosaRefTensorHandleFactory>(memoryManager);
52 // Register copy and import factory pair
53 tensorHandleFactoryRegistry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId());
54 // Register the factory
55 tensorHandleFactoryRegistry.RegisterFactory(std::move(factory));
56
57 return std::make_unique<TosaRefWorkloadFactory>(PolymorphicPointerDowncast<TosaRefMemoryManager>(memoryManager));
58}
59
60IBackendInternal::IBackendContextPtr TosaRefBackend::CreateBackendContext(const IRuntime::CreationOptions&) const
61{
62 return IBackendContextPtr{};
63}
64
65IBackendInternal::IBackendProfilingContextPtr TosaRefBackend::CreateBackendProfilingContext(
66 const IRuntime::CreationOptions&, IBackendProfilingPtr&)
67{
68 return IBackendProfilingContextPtr{};
69}
70
71IBackendInternal::IMemoryManagerUniquePtr TosaRefBackend::CreateMemoryManager() const
72{
73 return std::make_unique<TosaRefMemoryManager>();
74}
75
76IBackendInternal::ILayerSupportSharedPtr TosaRefBackend::GetLayerSupport() const
77{
78 static ILayerSupportSharedPtr layerSupport{new TosaRefLayerSupport};
79 return layerSupport;
80}
81
82OptimizationViews TosaRefBackend::OptimizeSubgraphView(const SubgraphView& subgraph,
83 const ModelOptions& modelOptions) const
84{
85 OptimizationViews optimizationViews(modelOptions);
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000086
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010087 auto handler = std::make_unique<TosaSerializationHandler>();
Francis Murtagh9270d9e2022-08-12 13:54:17 +010088
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000089 std::vector<std::string> graphInputs;
90 std::vector<std::string> graphOutputs;
91
92 std::vector<TosaSerializationOperator*> operators;
93 std::vector<TosaSerializationTensor*> tensors;
Matthew Sloyan5c54c382022-11-09 16:28:51 +000094
Teresa Charlin571a4f72024-03-26 11:18:42 +000095 auto it = subgraph.begin();
96 while (it != subgraph.end())
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010097 {
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +000098 Layer& base = *(PolymorphicDowncast<Layer*>(*it));
Matthew Sloyan164bf4f2022-10-28 18:02:17 +010099
100 if(base.GetType() == armnn::LayerType::Input ||
101 base.GetType() == armnn::LayerType::Output)
102 {
Teresa Charlin571a4f72024-03-26 11:18:42 +0000103 ++it;
Matthew Sloyan164bf4f2022-10-28 18:02:17 +0100104 continue;
105 }
106
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000107 tosa::TosaSerializationBasicBlock* mappings = GetTosaMappingFromLayer(&base);
Matthew Sloyan5c54c382022-11-09 16:28:51 +0000108
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000109 // Loop through inputs to see if there are any graph inputs, if so save them.
110 // If it's an input to the graph "input" can be found in the string.
Teresa Charlin571a4f72024-03-26 11:18:42 +0000111 for (const std::string& blockInputName : mappings->GetInputs())
Matthew Sloyan5c54c382022-11-09 16:28:51 +0000112 {
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000113 if (blockInputName.find("input") != std::string::npos)
114 {
115 graphInputs.push_back(blockInputName);
116 }
Matthew Sloyan5c54c382022-11-09 16:28:51 +0000117 }
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000118
119 // Loop through outputs to see if there are any graph outputs, if so save them.
120 // If it's an output to the graph "output" can be found in the string.
Teresa Charlin571a4f72024-03-26 11:18:42 +0000121 for (const std::string& blockOutputName : mappings->GetOutputs())
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000122 {
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000123 if (blockOutputName.find("output") != std::string::npos)
124 {
125 graphOutputs.push_back(blockOutputName);
126 }
127 }
128
129 auto blockOperators = mappings->GetOperators();
130 operators.insert(operators.end(), blockOperators.begin(), blockOperators.end());
131
132 auto blockTensors = mappings->GetTensors();
133 tensors.insert(tensors.end(), blockTensors.begin(), blockTensors.end());
Teresa Charlin571a4f72024-03-26 11:18:42 +0000134
135 ++it;
Matthew Sloyan164bf4f2022-10-28 18:02:17 +0100136 }
137
Narumol Prangnawaratad323af2023-09-29 17:00:38 +0100138 // Add all mappings to main block.
139 auto* block = new TosaSerializationBasicBlock("main", "main", operators, tensors, graphInputs, graphOutputs);
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000140
Narumol Prangnawaratad323af2023-09-29 17:00:38 +0100141 std::vector<TosaSerializationBasicBlock*> blocks;
142 blocks.emplace_back(block);
143
144 // Add blocks to the main region.
145 auto* region = new TosaSerializationRegion("main", blocks);
146 handler->GetRegions().emplace_back(region);
Matthew Sloyanc5fe6e72022-11-25 16:10:00 +0000147
Matthew Sloyan164bf4f2022-10-28 18:02:17 +0100148 auto compiledBlob =
149 std::make_unique<PreCompiledObjectPtr>(handler.release(), DeleteAsType<TosaSerializationHandler>);
150
151 IConnectableLayer* preCompiledLayer = optimizationViews.GetINetwork()->AddPrecompiledLayer(
152 PreCompiledDescriptor(subgraph.GetNumInputSlots(), subgraph.GetNumOutputSlots()),
153 std::move(*compiledBlob),
154 armnn::Optional<BackendId>(GetId()),
155 "TOSA_Pre_Compiled_Layer");
156
157 // Copy the output tensor infos from sub-graph
158 for (unsigned int i = 0; i < subgraph.GetNumOutputSlots(); i++)
159 {
160 preCompiledLayer->GetOutputSlot(i).SetTensorInfo(subgraph.GetIOutputSlot(i)->GetTensorInfo());
161 }
162
163 optimizationViews.AddSubstitution({ std::move(subgraph), SubgraphView(preCompiledLayer) });
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100164 return optimizationViews;
165}
166
Matthew Sloyan164bf4f2022-10-28 18:02:17 +0100167
Francis Murtagh9270d9e2022-08-12 13:54:17 +0100168std::vector<ITensorHandleFactory::FactoryId> TosaRefBackend::GetHandleFactoryPreferences() const
169{
170 return std::vector<ITensorHandleFactory::FactoryId> { TosaRefTensorHandleFactory::GetIdStatic() };
171}
172
173void TosaRefBackend::RegisterTensorHandleFactories(class TensorHandleFactoryRegistry& registry)
174{
175 auto memoryManager = std::make_shared<TosaRefMemoryManager>();
176
177 registry.RegisterMemoryManager(memoryManager);
178
179 auto factory = std::make_unique<TosaRefTensorHandleFactory>(memoryManager);
180
181 // Register copy and import factory pair
182 registry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId());
183 // Register the factory
184 registry.RegisterFactory(std::move(factory));
185}
186
187std::unique_ptr<ICustomAllocator> TosaRefBackend::GetDefaultAllocator() const
188{
189 return std::make_unique<DefaultAllocator>();
190}
191
192} // namespace armnn