blob: 6136f2bebe80c6aa940a11084372bcd1b9d46da8 [file] [log] [blame]
Sadik Armagan3c24f432020-10-19 17:35:30 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "DelegateOptions.hpp"
9
10#include <tensorflow/lite/builtin_ops.h>
11#include <tensorflow/lite/c/builtin_op_data.h>
12#include <tensorflow/lite/c/common.h>
13#include <tensorflow/lite/minimal_logging.h>
14
15namespace armnnDelegate
16{
17
18TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate);
19
20/// Delegate class
21class Delegate
22{
23 friend class ArmnnSubgraph;
24public:
25 explicit Delegate(armnnDelegate::DelegateOptions options);
26
27 TfLiteIntArray* CollectOperatorsToDelegate(TfLiteContext* context);
28
29 TfLiteDelegate* GetDelegate();
30
31private:
32 TfLiteDelegate m_Delegate = {
33 reinterpret_cast<void*>(this), // .data_
34 DelegatePrepare, // .Prepare
35 nullptr, // .CopyFromBufferHandle
36 nullptr, // .CopyToBufferHandle
37 nullptr, // .FreeBufferHandle
38 kTfLiteDelegateFlagsNone, // .flags
39 };
40
41 /// Arm NN Runtime pointer
42 armnn::IRuntimePtr m_Runtime;
43 /// Arm NN Delegate Options
44 armnnDelegate::DelegateOptions m_Options;
45};
46
47/// ArmnnSubgraph class where parsing the nodes to ArmNN format and creating the ArmNN Graph
48class ArmnnSubgraph
49{
50public:
51 static ArmnnSubgraph* Create(TfLiteContext* tfLiteContext,
52 const TfLiteDelegateParams* parameters,
53 const Delegate* delegate);
54
55 TfLiteStatus Prepare(TfLiteContext* tfLiteContext);
56
57 TfLiteStatus Invoke(TfLiteContext* tfLiteContext);
58
59 static TfLiteStatus VisitNode(armnn::INetworkPtr& network,
60 TfLiteContext* tfLiteContext,
61 TfLiteRegistration* tfLiteRegistration,
62 TfLiteNode* tfLiteNode,
63 int nodeIndex);
64
65private:
66 ArmnnSubgraph(armnn::NetworkId networkId, armnn::IRuntime* runtime)
67 : m_NetworkId(networkId), m_Runtime(runtime)
68 {}
69
70 /// The Network Id
71 armnn::NetworkId m_NetworkId;
72 /// ArmNN Rumtime
73 armnn::IRuntime* m_Runtime;
74};
75
76void* ArmnnSubgraphInit(TfLiteContext* tfLiteContext, const char* buffer, size_t length)
77{
78 const TfLiteDelegateParams* parameters = reinterpret_cast<const TfLiteDelegateParams*>(buffer);
79
80 return static_cast<void*>(ArmnnSubgraph::Create(
81 tfLiteContext, parameters, static_cast<::armnnDelegate::Delegate*>(parameters->delegate->data_)));
82}
83
84TfLiteStatus ArmnnSubgraphPrepare(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode)
85{
86 if (tfLiteNode->user_data == nullptr)
87 {
88 return kTfLiteError;
89 }
90
91 return static_cast<ArmnnSubgraph*>(tfLiteNode->user_data)->Prepare(tfLiteContext);
92}
93
94TfLiteStatus ArmnnSubgraphInvoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode)
95{
96 if (tfLiteNode->user_data == nullptr)
97 {
98 return kTfLiteError;
99 }
100
101 return static_cast<ArmnnSubgraph*>(tfLiteNode->user_data)->Invoke(tfLiteContext);
102}
103
104void ArmnnSubgraphFree(TfLiteContext* tfLiteContext, void* buffer)
105{
106 if (buffer != nullptr)
107 {
108 delete static_cast<ArmnnSubgraph*>(buffer);
109 }
110}
111
112const TfLiteRegistration armnnSubgraphRegistration = {
113 ArmnnSubgraphInit, // .init
114 ArmnnSubgraphFree, // .free
115 ArmnnSubgraphPrepare, // .prepare
116 ArmnnSubgraphInvoke, // .invoke
117 nullptr, // .profiling_string
118 0, // .builtin_code
119 "TfLiteArmnnDelegate", // .custom_name
120 1, // .version
121};
122
123TfLiteStatus DelegatePrepare(TfLiteContext* tfLiteContext, TfLiteDelegate* tfLiteDelegate)
124{
125 TfLiteIntArray* supportedOperators =
126 static_cast<::armnnDelegate::Delegate*>(tfLiteDelegate->data_)->CollectOperatorsToDelegate(tfLiteContext);
127
128 const TfLiteStatus status =
129 tfLiteContext->ReplaceNodeSubsetsWithDelegateKernels(
130 tfLiteContext, armnnSubgraphRegistration, supportedOperators, tfLiteDelegate);
131 TfLiteIntArrayFree(supportedOperators);
132
133 return status;
134}
135
136} // armnnDelegate namespace
137
138armnnDelegate::DelegateOptions TfLiteArmnnDelegateOptionsDefault() {
139 armnnDelegate::DelegateOptions options(armnn::Compute::CpuRef);
140 return options;
141}
142
143TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options)
144{
145 auto* armnnDelegate = new ::armnnDelegate::Delegate(options);
146 return armnnDelegate->GetDelegate();
147}
148
149void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate)
150{
151 if (tfLiteDelegate != nullptr)
152 {
153 delete static_cast<::armnnDelegate::Delegate*>(tfLiteDelegate->data_);
154 }
155}