blob: 6136f2bebe80c6aa940a11084372bcd1b9d46da8 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include "DelegateOptions.hpp"
#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>
#include <utility>
namespace armnnDelegate
{
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate);

/// Delegate class.
///
/// Holds the Arm NN runtime and delegate options for one delegate instance, and the
/// TfLiteDelegate struct handed to TfLite. m_Delegate.data_ stores `this`, which is how the
/// TfLite callbacks (DelegatePrepare, ArmnnSubgraphInit, TfLiteArmnnDelegateDelete) get back
/// to this object.
///
/// Because data_ is a pointer to this very object, a Delegate must never be copied or moved:
/// the new object would carry a data_ that still points at the original. Copy construction and
/// copy assignment are deleted, which also suppresses the implicitly-declared move operations.
class Delegate
{
    friend class ArmnnSubgraph;
public:
    explicit Delegate(armnnDelegate::DelegateOptions options);

    Delegate(const Delegate&) = delete;
    Delegate& operator=(const Delegate&) = delete;

    /// Returns the list of node indices to hand to this delegate. The caller owns the array
    /// and releases it with TfLiteIntArrayFree (as DelegatePrepare does).
    TfLiteIntArray* CollectOperatorsToDelegate(TfLiteContext* context);

    /// Returns the TfLiteDelegate to register with the TfLite interpreter
    /// (presumably &m_Delegate — TODO confirm against the implementation).
    TfLiteDelegate* GetDelegate();

private:
    TfLiteDelegate m_Delegate = {
        static_cast<void*>(this),  // .data_
        DelegatePrepare,           // .Prepare
        nullptr,                   // .CopyFromBufferHandle
        nullptr,                   // .CopyToBufferHandle
        nullptr,                   // .FreeBufferHandle
        kTfLiteDelegateFlagsNone,  // .flags
    };
    /// Arm NN Runtime pointer
    armnn::IRuntimePtr m_Runtime;
    /// Arm NN Delegate Options
    armnnDelegate::DelegateOptions m_Options;
};
/// ArmnnSubgraph class where parsing the nodes to ArmNN format and creating the ArmNN Graph
class ArmnnSubgraph
{
public:
static ArmnnSubgraph* Create(TfLiteContext* tfLiteContext,
const TfLiteDelegateParams* parameters,
const Delegate* delegate);
TfLiteStatus Prepare(TfLiteContext* tfLiteContext);
TfLiteStatus Invoke(TfLiteContext* tfLiteContext);
static TfLiteStatus VisitNode(armnn::INetworkPtr& network,
TfLiteContext* tfLiteContext,
TfLiteRegistration* tfLiteRegistration,
TfLiteNode* tfLiteNode,
int nodeIndex);
private:
ArmnnSubgraph(armnn::NetworkId networkId, armnn::IRuntime* runtime)
: m_NetworkId(networkId), m_Runtime(runtime)
{}
/// The Network Id
armnn::NetworkId m_NetworkId;
/// ArmNN Rumtime
armnn::IRuntime* m_Runtime;
};
/// TfLiteRegistration::init callback for a delegated partition.
///
/// For delegate kernels TfLite passes a pointer to the TfLiteDelegateParams of the partition in
/// `buffer`; `length` is unused. The returned ArmnnSubgraph becomes the node's user_data.
/// Returns nullptr when the parameters or their delegate are missing; ArmnnSubgraphPrepare and
/// ArmnnSubgraphInvoke then report kTfLiteError instead of dereferencing a null pointer.
///
/// Declared inline because it is defined in a header that may be included by several
/// translation units.
inline void* ArmnnSubgraphInit(TfLiteContext* tfLiteContext, const char* buffer, size_t length)
{
    const auto* parameters = reinterpret_cast<const TfLiteDelegateParams*>(buffer);
    if (parameters == nullptr || parameters->delegate == nullptr)
    {
        return nullptr;
    }
    auto* delegate = static_cast<::armnnDelegate::Delegate*>(parameters->delegate->data_);
    return static_cast<void*>(ArmnnSubgraph::Create(tfLiteContext, parameters, delegate));
}
/// TfLiteRegistration::prepare callback for a delegated partition.
///
/// Forwards to ArmnnSubgraph::Prepare on the subgraph stored in the node's user_data. Returns
/// kTfLiteError if there is no subgraph (e.g. ArmnnSubgraphInit returned nullptr).
/// Declared inline because it is defined in a header.
inline TfLiteStatus ArmnnSubgraphPrepare(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode)
{
    if (tfLiteNode->user_data == nullptr)
    {
        return kTfLiteError;
    }
    return static_cast<ArmnnSubgraph*>(tfLiteNode->user_data)->Prepare(tfLiteContext);
}
/// TfLiteRegistration::invoke callback for a delegated partition.
///
/// Forwards to ArmnnSubgraph::Invoke on the subgraph stored in the node's user_data. Returns
/// kTfLiteError if there is no subgraph (e.g. ArmnnSubgraphInit returned nullptr).
/// Declared inline because it is defined in a header.
inline TfLiteStatus ArmnnSubgraphInvoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode)
{
    if (tfLiteNode->user_data == nullptr)
    {
        return kTfLiteError;
    }
    return static_cast<ArmnnSubgraph*>(tfLiteNode->user_data)->Invoke(tfLiteContext);
}
/// TfLiteRegistration::free callback for a delegated partition.
///
/// Destroys the ArmnnSubgraph created by ArmnnSubgraphInit (`buffer` is that pointer, possibly
/// null). `tfLiteContext` is unused. Declared inline because it is defined in a header.
inline void ArmnnSubgraphFree(TfLiteContext* tfLiteContext, void* buffer)
{
    // Deleting a null pointer is a no-op, so no explicit check is required.
    delete static_cast<ArmnnSubgraph*>(buffer);
}
// Kernel registration that DelegatePrepare passes to ReplaceNodeSubsetsWithDelegateKernels, so
// TfLite drives each delegated node subset through the ArmnnSubgraph* callbacks below.
// The initializers are positional and must follow the TfLiteRegistration field order of the
// TfLite version in use; any trailing fields not listed here are zero-initialized.
// NOTE(review): a namespace-scope const has internal linkage, so every translation unit that
// includes this header gets its own copy of this object.
const TfLiteRegistration armnnSubgraphRegistration = {
ArmnnSubgraphInit, // .init
ArmnnSubgraphFree, // .free
ArmnnSubgraphPrepare, // .prepare
ArmnnSubgraphInvoke, // .invoke
nullptr, // .profiling_string
0, // .builtin_code
"TfLiteArmnnDelegate", // .custom_name
1, // .version
};
/// TfLiteDelegate::Prepare callback.
///
/// Asks the Delegate (recovered from tfLiteDelegate->data_) which nodes it supports, and
/// replaces those node subsets with delegate kernels described by armnnSubgraphRegistration.
/// Returns the status of that replacement. Returns kTfLiteError, without touching the graph, if
/// no operator list is produced; passing a null list on would be dereferenced by TfLite.
/// Declared inline because it is defined in a header.
inline TfLiteStatus DelegatePrepare(TfLiteContext* tfLiteContext, TfLiteDelegate* tfLiteDelegate)
{
    auto* delegate = static_cast<::armnnDelegate::Delegate*>(tfLiteDelegate->data_);
    TfLiteIntArray* supportedOperators = delegate->CollectOperatorsToDelegate(tfLiteContext);
    if (supportedOperators == nullptr)
    {
        return kTfLiteError;
    }

    const TfLiteStatus status = tfLiteContext->ReplaceNodeSubsetsWithDelegateKernels(
        tfLiteContext, armnnSubgraphRegistration, supportedOperators, tfLiteDelegate);

    // The operator list is owned by this function; release it whatever the outcome.
    TfLiteIntArrayFree(supportedOperators);
    return status;
}
} // armnnDelegate namespace
/// Returns DelegateOptions constructed with armnn::Compute::CpuRef as the backend.
/// Declared inline because it is defined in a header.
inline armnnDelegate::DelegateOptions TfLiteArmnnDelegateOptionsDefault()
{
    return armnnDelegate::DelegateOptions(armnn::Compute::CpuRef);
}
/// Creates a new Arm NN delegate configured with `options` and returns its TfLiteDelegate.
///
/// The returned pointer must be released with TfLiteArmnnDelegateDelete, which deletes the
/// underlying armnnDelegate::Delegate through TfLiteDelegate::data_.
/// Declared inline because it is defined in a header.
inline TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options)
{
    // `options` is a by-value sink, so move it into the Delegate instead of copying it again.
    auto* delegate = new ::armnnDelegate::Delegate(std::move(options));
    return delegate->GetDelegate();
}
/// Destroys a delegate previously returned by TfLiteArmnnDelegateCreate. A null argument is
/// ignored.
///
/// The Delegate is recovered from tfLiteDelegate->data_. The TfLiteDelegate is presumably a member
/// of that Delegate (TODO confirm via Delegate::GetDelegate), so `tfLiteDelegate` should be
/// treated as dangling afterwards. Declared inline because it is defined in a header.
inline void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate)
{
    if (tfLiteDelegate != nullptr)
    {
        delete static_cast<::armnnDelegate::Delegate*>(tfLiteDelegate->data_);
    }
}