blob: e94a6e2d4e99ced93b491349b19904fb77dce711 [file] [log] [blame]
//
// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <DelegateOptions.hpp>
9
10#include <tensorflow/lite/builtin_ops.h>
11#include <tensorflow/lite/c/builtin_op_data.h>
12#include <tensorflow/lite/c/common.h>
13#include <tensorflow/lite/minimal_logging.h>
14#include <tensorflow/lite/version.h>
15
16#if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 3)
17#define ARMNN_POST_TFLITE_2_3
18#endif
19
20#if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 4)
21#define ARMNN_POST_TFLITE_2_4
22#endif
23
24#if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 5)
25#define ARMNN_POST_TFLITE_2_5
26#endif
27
28namespace armnnDelegate
29{
30
31struct DelegateData
32{
33 DelegateData(const std::vector<armnn::BackendId>& backends)
34 : m_Backends(backends)
35 , m_Network(nullptr, nullptr)
36 {}
37
38 const std::vector<armnn::BackendId> m_Backends;
39 armnn::INetworkPtr m_Network;
40 std::vector<armnn::IOutputSlot*> m_OutputSlotForNode;
41};
42
43// Forward decleration for functions initializing the ArmNN Delegate
44DelegateOptions TfLiteArmnnDelegateOptionsDefault();
45
46TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options);
47
48void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate);
49
50TfLiteStatus DoPrepare(TfLiteContext* context, TfLiteDelegate* delegate);
51
52/// ArmNN Delegate
53class Delegate
54{
55 friend class ArmnnSubgraph;
56public:
57 explicit Delegate(armnnDelegate::DelegateOptions options);
58
59 TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteContext* context);
60
61 TfLiteDelegate* GetDelegate();
62
63 /// Retrieve version in X.Y.Z form
64 static const std::string GetVersion();
65
66private:
67 /**
68 * Returns a pointer to the armnn::IRuntime* this will be shared by all armnn_delegates.
69 */
70 armnn::IRuntime* GetRuntime(const armnn::IRuntime::CreationOptions& options)
71 {
72 static armnn::IRuntimePtr instance = armnn::IRuntime::Create(options);
73 // Instantiated on first use.
74 return instance.get();
75 }
76
77 TfLiteDelegate m_Delegate = {
78 reinterpret_cast<void*>(this), // .data_
79 DoPrepare, // .Prepare
80 nullptr, // .CopyFromBufferHandle
81 nullptr, // .CopyToBufferHandle
82 nullptr, // .FreeBufferHandle
83 kTfLiteDelegateFlagsNone, // .flags
84 nullptr, // .opaque_delegate_builder
85 };
86
87 /// ArmNN Runtime pointer
88 armnn::IRuntime* m_Runtime;
89 /// ArmNN Delegate Options
90 armnnDelegate::DelegateOptions m_Options;
91};
92
93/// ArmnnSubgraph class where parsing the nodes to ArmNN format and creating the ArmNN Graph
94class ArmnnSubgraph
95{
96public:
97 static ArmnnSubgraph* Create(TfLiteContext* tfLiteContext,
98 const TfLiteDelegateParams* parameters,
99 const Delegate* delegate);
100
101 TfLiteStatus Prepare(TfLiteContext* tfLiteContext);
102
103 TfLiteStatus Invoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode);
104
105 static TfLiteStatus VisitNode(DelegateData& delegateData,
106 TfLiteContext* tfLiteContext,
107 TfLiteRegistration* tfLiteRegistration,
108 TfLiteNode* tfLiteNode,
109 int nodeIndex);
110
111private:
112 ArmnnSubgraph(armnn::NetworkId networkId,
113 armnn::IRuntime* runtime,
114 std::vector<armnn::BindingPointInfo>& inputBindings,
115 std::vector<armnn::BindingPointInfo>& outputBindings)
116 : m_NetworkId(networkId), m_Runtime(runtime), m_InputBindings(inputBindings), m_OutputBindings(outputBindings)
117 {}
118
119 static TfLiteStatus AddInputLayer(DelegateData& delegateData,
120 TfLiteContext* tfLiteContext,
121 const TfLiteIntArray* inputs,
122 std::vector<armnn::BindingPointInfo>& inputBindings);
123
124 static TfLiteStatus AddOutputLayer(DelegateData& delegateData,
125 TfLiteContext* tfLiteContext,
126 const TfLiteIntArray* outputs,
127 std::vector<armnn::BindingPointInfo>& outputBindings);
128
129
130 /// The Network Id
131 armnn::NetworkId m_NetworkId;
Francis Murtaghc4fb0dd2023-03-16 17:01:56 +0000132 /// ArmNN Runtime
Teresa Charlinad1b3d72023-03-14 12:10:28 +0000133 armnn::IRuntime* m_Runtime;
134
135 // Binding information for inputs and outputs
136 std::vector<armnn::BindingPointInfo> m_InputBindings;
137 std::vector<armnn::BindingPointInfo> m_OutputBindings;
138
139};
140
141} // armnnDelegate namespace