blob: f8a8aca13900b787f9b1f1b2a3ffc50c9399fc31 [file] [log] [blame]
Sadik Armagan3c24f432020-10-19 17:35:30 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include <armnn_delegate.hpp>

#include <algorithm>
#include <string>
8
9namespace armnnDelegate
10{
11
12Delegate::Delegate(armnnDelegate::DelegateOptions options)
13 : m_Runtime(nullptr, nullptr),
14 m_Options(std::move(options))
15{
16 // Create ArmNN Runtime
17 armnn::IRuntime::CreationOptions runtimeOptions;
18 m_Runtime = armnn::IRuntime::Create(runtimeOptions);
19
20 std::vector<armnn::BackendId> backends;
21
22 if (m_Runtime)
23 {
24 const armnn::BackendIdSet supportedDevices = m_Runtime->GetDeviceSpec().GetSupportedBackends();
25 for (auto& backend : m_Options.GetBackends())
26 {
27 if (std::find(supportedDevices.cbegin(), supportedDevices.cend(), backend) == supportedDevices.cend())
28 {
29 TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
30 "TfLiteArmnnDelegate: Requested unknown backend %s", backend.Get().c_str());
31 }
32 else
33 {
34 backends.push_back(backend);
35 }
36 }
37 }
38
39 if (backends.empty())
40 {
41 // No known backend specified
42 throw armnn::InvalidArgumentException("TfLiteArmnnDelegate: No known backend specified.");
43 }
44 m_Options.SetBackends(backends);
45
46 TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "TfLiteArmnnDelegate: Created TfLite ArmNN delegate.");
47}
48
49TfLiteIntArray* Delegate::CollectOperatorsToDelegate(TfLiteContext* tfLiteContext)
50{
51 TfLiteIntArray* executionPlan = nullptr;
52 if (tfLiteContext->GetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk)
53 {
54 TF_LITE_KERNEL_LOG(tfLiteContext, "TfLiteArmnnDelegate: Unable to get graph execution plan.");
55 return nullptr;
56 }
57
58 // Null INetworkPtr
59 armnn::INetworkPtr nullNetworkPtr(nullptr, nullptr);
60
61 TfLiteIntArray* nodesToDelegate = TfLiteIntArrayCreate(executionPlan->size);
62 nodesToDelegate->size = 0;
63 for (int i = 0; i < executionPlan->size; ++i)
64 {
65 const int nodeIndex = executionPlan->data[i];
66
67 // If TfLite nodes can be delegated to ArmNN
68 TfLiteNode* tfLiteNode = nullptr;
69 TfLiteRegistration* tfLiteRegistration = nullptr;
70 if (tfLiteContext->GetNodeAndRegistration(
71 tfLiteContext, nodeIndex, &tfLiteNode, &tfLiteRegistration) != kTfLiteOk)
72 {
73 TF_LITE_KERNEL_LOG(tfLiteContext,
74 "TfLiteArmnnDelegate: Unable to get node and registration for node %d.",
75 nodeIndex);
76 continue;
77 }
78
79 if (ArmnnSubgraph::VisitNode(
80 nullNetworkPtr, tfLiteContext, tfLiteRegistration, tfLiteNode, nodeIndex) != kTfLiteOk)
81 {
82 // node is not supported by ArmNN
83 continue;
84 }
85
86 nodesToDelegate->data[nodesToDelegate->size++] = nodeIndex;
87 }
88
89 std::sort(&nodesToDelegate->data[0],
90 &nodesToDelegate->data[nodesToDelegate->size]);
91
92 return nodesToDelegate;
93}
94
95TfLiteDelegate* Delegate::GetDelegate()
96{
97 return &m_Delegate;
98}
99
100ArmnnSubgraph* ArmnnSubgraph::Create(TfLiteContext* tfLiteContext,
101 const TfLiteDelegateParams* parameters,
102 const Delegate* delegate)
103{
104 TfLiteIntArray* executionPlan;
105 if (tfLiteContext->GetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk)
106 {
107 return nullptr;
108 }
109
110 // Construct ArmNN network
111 using NetworkOptions = std::vector<armnn::BackendOptions>;
112 armnn::NetworkOptions networkOptions = {};
113 armnn::NetworkId networkId;
114 armnn::INetworkPtr network = armnn::INetwork::Create(networkOptions);
115
116 // Parse TfLite delegate nodes to ArmNN nodes
117 for (int i = 0; i < parameters->nodes_to_replace->size; ++i)
118 {
119 const int nodeIndex = parameters->nodes_to_replace->data[i];
120
121 TfLiteNode* tfLiteNode = nullptr;
122 TfLiteRegistration* tfLiteRegistration = nullptr;
123 if (tfLiteContext->GetNodeAndRegistration(
124 tfLiteContext, nodeIndex, &tfLiteNode, &tfLiteRegistration) != kTfLiteOk)
125 {
126 throw armnn::Exception("TfLiteArmnnDelegate: Unable to get node registration: " + nodeIndex);
127 }
128
129 if (VisitNode(network, tfLiteContext, tfLiteRegistration, tfLiteNode, nodeIndex) != kTfLiteOk)
130 {
131 throw armnn::Exception("TfLiteArmnnDelegate: Unable to parse node: " + nodeIndex);
132 }
133 }
134
135 // Optimise Arm NN network
136 armnn::IOptimizedNetworkPtr optNet =
137 armnn::Optimize(*network, delegate->m_Options.GetBackends(), delegate->m_Runtime->GetDeviceSpec());
138 if (!optNet)
139 {
140 // Optimize Failed
141 throw armnn::Exception("TfLiteArmnnDelegate: Unable to optimize the network!");
142 }
143 // Load graph into runtime
144 delegate->m_Runtime->LoadNetwork(networkId, std::move(optNet));
145
146 // Create a new SubGraph with networkId and runtime
147 return new ArmnnSubgraph(networkId, delegate->m_Runtime.get());
148}
149
150TfLiteStatus ArmnnSubgraph::Prepare(TfLiteContext* tfLiteContext)
151{
152 return kTfLiteOk;
153}
154
155TfLiteStatus ArmnnSubgraph::Invoke(TfLiteContext* tfLiteContext)
156{
157 /// Get the Input Tensors and OutputTensors from the context
158 /// Execute the network
159 //m_Runtime->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
160
161 return kTfLiteOk;
162}
163
164TfLiteStatus ArmnnSubgraph::VisitNode(armnn::INetworkPtr& network,
165 TfLiteContext* tfLiteContext,
166 TfLiteRegistration* tfLiteRegistration,
167 TfLiteNode* tfLiteNode,
168 int nodeIndex)
169{
170 /*
171 * Take the node and check what operator it is and VisitXXXLayer()
172 * In the VisitXXXLayer() function parse TfLite node to Arm NN Layer and add it to tho network graph
173 *switch (tfLiteRegistration->builtin_code)
174 * {
175 * case kTfLiteBuiltinAbs:
176 * return VisitAbsLayer(...);
177 * ...
178 * default:
179 * return kTfLiteError;
180 * }
181 */
182 return kTfLiteError;
183}
184
185} // armnnDelegate namespace