IVGCVSW-5389 'TfLiteDelegate: Implement the FullyConnected operator'
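* Added FullyConnected operator support to the delegate
* Do not create input layers or bindings for constant (kTfLiteMmapRo) input tensors
* Size m_OutputSlotForNode by the number of tensors in the TfLiteContext (it is indexed by tensor id), not by the number of nodes to replace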
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: Iae9c0980a4bfd6aa4d90f107f329dfa782baeefe
diff --git a/delegate/src/armnn_delegate.cpp b/delegate/src/armnn_delegate.cpp
index 82cf573..69bd4f7 100644
--- a/delegate/src/armnn_delegate.cpp
+++ b/delegate/src/armnn_delegate.cpp
@@ -85,7 +85,6 @@
{
return kTfLiteError;
}
-
return static_cast<ArmnnSubgraph*>(tfLiteNode->user_data)->Prepare(tfLiteContext);
},
// ArmnnSubgraph Invoke
@@ -209,6 +208,11 @@
{
const int32_t tensorId = inputs->data[i];
const TfLiteTensor tensor = tfLiteContext->tensors[tensorId];
+ // Do not create bindings for constant inputs
+ if (tensor.allocation_type == kTfLiteMmapRo)
+ {
+ continue;
+ }
auto bindingId = static_cast<armnn::LayerBindingId>((tensorId));
armnn::IConnectableLayer* layer = delegateData.m_Network->AddInputLayer(bindingId);
@@ -220,12 +224,9 @@
// Store for creating connections
delegateData.m_OutputSlotForNode[tensorId] = &outputSlot;
- // Do not create bindings for constant inputs
- if (tensor.allocation_type != kTfLiteMmapRo)
- {
- inputBindings.push_back(std::make_pair(bindingId, tensorInfo));
- }
+ inputBindings.push_back(std::make_pair(bindingId, tensorInfo));
}
+
return kTfLiteOk;
}
@@ -244,7 +245,6 @@
armnn::IConnectableLayer* layer = delegateData.m_Network->AddOutputLayer(bindingId);
auto tensorInfo = GetTensorInfoForTfLiteTensor(tensor);
-
ARMNN_ASSERT(delegateData.m_OutputSlotForNode[tensorId] != nullptr);
delegateData.m_OutputSlotForNode[tensorId]->Connect(layer->GetInputSlot(0));
outputBindings.push_back(std::make_pair(bindingId, tensorInfo));
@@ -272,7 +272,8 @@
armnn::NetworkId networkId;
delegateData.m_Network = armnn::INetwork::Create(networkOptions);
- delegateData.m_OutputSlotForNode = std::vector<armnn::IOutputSlot*>(parameters->nodes_to_replace->size, nullptr);
+ delegateData.m_OutputSlotForNode = std::vector<armnn::IOutputSlot*>(tfLiteContext->tensors_size, nullptr);
+
std::vector<armnn::BindingPointInfo> inputBindings;
std::vector<armnn::BindingPointInfo> outputBindings;
@@ -314,8 +315,7 @@
armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
try
{
-
- optNet = armnn::Optimize(*(delegateData.m_Network),
+ optNet = armnn::Optimize(*(delegateData.m_Network.get()),
delegate->m_Options.GetBackends(),
delegate->m_Runtime->GetDeviceSpec());
}