blob: a6957dc7d585431c9f9c773d7918fb4d020201ec [file] [log] [blame]
Teresa Charlinad1b3d72023-03-14 12:10:28 +00001//
Colm Donelan253d1bb2024-03-04 22:19:26 +00002// Copyright © 2020-2024 Arm Ltd and Contributors. All rights reserved.
Teresa Charlinad1b3d72023-03-14 12:10:28 +00003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <DelegateOptions.hpp>
9
10#include <tensorflow/lite/builtin_ops.h>
11#include <tensorflow/lite/c/builtin_op_data.h>
12#include <tensorflow/lite/c/common.h>
13#include <tensorflow/lite/minimal_logging.h>
14#include <tensorflow/lite/version.h>
15
// Feature-test macros keyed on the TensorFlow Lite version reported by TF_MAJOR_VERSION /
// TF_MINOR_VERSION. ARMNN_POST_TFLITE_2_N is defined when the headers describe a release
// newer than 2.N, i.e. any 3.x (or later) release, or 2.M with M > N.
#if TF_MAJOR_VERSION > 2
    // Every 2.x threshold is exceeded by a newer major release.
    #define ARMNN_POST_TFLITE_2_3
    #define ARMNN_POST_TFLITE_2_4
    #define ARMNN_POST_TFLITE_2_5
#elif TF_MAJOR_VERSION == 2
    #if TF_MINOR_VERSION > 3
        #define ARMNN_POST_TFLITE_2_3
    #endif
    #if TF_MINOR_VERSION > 4
        #define ARMNN_POST_TFLITE_2_4
    #endif
    #if TF_MINOR_VERSION > 5
        #define ARMNN_POST_TFLITE_2_5
    #endif
#endif
27
28namespace armnnDelegate
29{
30
31struct DelegateData
32{
33 DelegateData(const std::vector<armnn::BackendId>& backends)
34 : m_Backends(backends)
35 , m_Network(nullptr, nullptr)
36 {}
37
38 const std::vector<armnn::BackendId> m_Backends;
39 armnn::INetworkPtr m_Network;
40 std::vector<armnn::IOutputSlot*> m_OutputSlotForNode;
41};
42
43// Forward decleration for functions initializing the ArmNN Delegate
44DelegateOptions TfLiteArmnnDelegateOptionsDefault();
45
46TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options);
47
48void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate);
49
50TfLiteStatus DoPrepare(TfLiteContext* context, TfLiteDelegate* delegate);
51
52/// ArmNN Delegate
53class Delegate
54{
55 friend class ArmnnSubgraph;
56public:
57 explicit Delegate(armnnDelegate::DelegateOptions options);
58
59 TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteContext* context);
60
61 TfLiteDelegate* GetDelegate();
62
63 /// Retrieve version in X.Y.Z form
64 static const std::string GetVersion();
65
66private:
67 /**
68 * Returns a pointer to the armnn::IRuntime* this will be shared by all armnn_delegates.
69 */
70 armnn::IRuntime* GetRuntime(const armnn::IRuntime::CreationOptions& options)
71 {
72 static armnn::IRuntimePtr instance = armnn::IRuntime::Create(options);
73 // Instantiated on first use.
74 return instance.get();
75 }
76
77 TfLiteDelegate m_Delegate = {
78 reinterpret_cast<void*>(this), // .data_
79 DoPrepare, // .Prepare
80 nullptr, // .CopyFromBufferHandle
81 nullptr, // .CopyToBufferHandle
82 nullptr, // .FreeBufferHandle
83 kTfLiteDelegateFlagsNone, // .flags
84 nullptr, // .opaque_delegate_builder
85 };
86
87 /// ArmNN Runtime pointer
88 armnn::IRuntime* m_Runtime;
89 /// ArmNN Delegate Options
90 armnnDelegate::DelegateOptions m_Options;
91};
92
93/// ArmnnSubgraph class where parsing the nodes to ArmNN format and creating the ArmNN Graph
94class ArmnnSubgraph
95{
96public:
97 static ArmnnSubgraph* Create(TfLiteContext* tfLiteContext,
98 const TfLiteDelegateParams* parameters,
99 const Delegate* delegate);
100
Colm Donelan253d1bb2024-03-04 22:19:26 +0000101 ~ArmnnSubgraph();
102
Teresa Charlinad1b3d72023-03-14 12:10:28 +0000103 TfLiteStatus Prepare(TfLiteContext* tfLiteContext);
104
105 TfLiteStatus Invoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode);
106
107 static TfLiteStatus VisitNode(DelegateData& delegateData,
108 TfLiteContext* tfLiteContext,
109 TfLiteRegistration* tfLiteRegistration,
110 TfLiteNode* tfLiteNode,
111 int nodeIndex);
112
113private:
114 ArmnnSubgraph(armnn::NetworkId networkId,
115 armnn::IRuntime* runtime,
116 std::vector<armnn::BindingPointInfo>& inputBindings,
117 std::vector<armnn::BindingPointInfo>& outputBindings)
118 : m_NetworkId(networkId), m_Runtime(runtime), m_InputBindings(inputBindings), m_OutputBindings(outputBindings)
119 {}
120
121 static TfLiteStatus AddInputLayer(DelegateData& delegateData,
122 TfLiteContext* tfLiteContext,
123 const TfLiteIntArray* inputs,
124 std::vector<armnn::BindingPointInfo>& inputBindings);
125
126 static TfLiteStatus AddOutputLayer(DelegateData& delegateData,
127 TfLiteContext* tfLiteContext,
128 const TfLiteIntArray* outputs,
129 std::vector<armnn::BindingPointInfo>& outputBindings);
130
131
132 /// The Network Id
133 armnn::NetworkId m_NetworkId;
Francis Murtaghc4fb0dd2023-03-16 17:01:56 +0000134 /// ArmNN Runtime
Teresa Charlinad1b3d72023-03-14 12:10:28 +0000135 armnn::IRuntime* m_Runtime;
136
137 // Binding information for inputs and outputs
138 std::vector<armnn::BindingPointInfo> m_InputBindings;
139 std::vector<armnn::BindingPointInfo> m_OutputBindings;
140
141};
142
143} // armnnDelegate namespace