//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#pragma once

#include "DelegateOptions.hpp"

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>
#include <tensorflow/lite/version.h>

#include <string>
#include <utility>
#include <vector>
15
// Feature-test macros derived from TF_MAJOR_VERSION / TF_MINOR_VERSION.
// Each ARMNN_POST_TFLITE_X_Y macro is defined only when the TensorFlow version is
// strictly newer than X.Y, i.e. major > X, or major == X and minor > Y.
// NOTE(review): the TF_* version macros are presumably supplied by the TensorFlow
// headers included above; if they are undefined the preprocessor treats them as 0
// and none of the macros below is defined.

// Defined for TensorFlow versions newer than 2.3.
#if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 3)
#define ARMNN_POST_TFLITE_2_3
#endif

// Defined for TensorFlow versions newer than 2.4.
#if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 4)
#define ARMNN_POST_TFLITE_2_4
#endif

// Defined for TensorFlow versions newer than 2.5.
#if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 5)
#define ARMNN_POST_TFLITE_2_5
#endif
Sadik Armagan3c24f432020-10-19 17:35:30 +010027
28namespace armnnDelegate
29{
30
Sadik Armagan62483be2020-10-23 17:14:43 +010031struct DelegateData
32{
33 DelegateData(const std::vector<armnn::BackendId>& backends)
34 : m_Backends(backends)
35 , m_Network(nullptr, nullptr)
36 {}
Sadik Armagan3c24f432020-10-19 17:35:30 +010037
Sadik Armagan62483be2020-10-23 17:14:43 +010038 const std::vector<armnn::BackendId> m_Backends;
39 armnn::INetworkPtr m_Network;
40 std::vector<armnn::IOutputSlot*> m_OutputSlotForNode;
41};
42
43// Forward decleration for functions initializing the ArmNN Delegate
44DelegateOptions TfLiteArmnnDelegateOptionsDefault();
45
46TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options);
47
48void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate);
49
50TfLiteStatus DoPrepare(TfLiteContext* context, TfLiteDelegate* delegate);
51
52/// ArmNN Delegate
Sadik Armagan3c24f432020-10-19 17:35:30 +010053class Delegate
54{
55 friend class ArmnnSubgraph;
56public:
57 explicit Delegate(armnnDelegate::DelegateOptions options);
58
Sadik Armagan62483be2020-10-23 17:14:43 +010059 TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteContext* context);
Sadik Armagan3c24f432020-10-19 17:35:30 +010060
61 TfLiteDelegate* GetDelegate();
62
Matthew Sloyanac001ee2021-02-03 10:43:04 +000063 /// Retrieve version in X.Y.Z form
64 static const std::string GetVersion();
65
Sadik Armagan3c24f432020-10-19 17:35:30 +010066private:
67 TfLiteDelegate m_Delegate = {
68 reinterpret_cast<void*>(this), // .data_
Sadik Armagan62483be2020-10-23 17:14:43 +010069 DoPrepare, // .Prepare
Sadik Armagan3c24f432020-10-19 17:35:30 +010070 nullptr, // .CopyFromBufferHandle
71 nullptr, // .CopyToBufferHandle
72 nullptr, // .FreeBufferHandle
73 kTfLiteDelegateFlagsNone, // .flags
74 };
75
Sadik Armagan62483be2020-10-23 17:14:43 +010076 /// ArmNN Runtime pointer
Sadik Armagan3c24f432020-10-19 17:35:30 +010077 armnn::IRuntimePtr m_Runtime;
Sadik Armagan62483be2020-10-23 17:14:43 +010078 /// ArmNN Delegate Options
Sadik Armagan3c24f432020-10-19 17:35:30 +010079 armnnDelegate::DelegateOptions m_Options;
80};
81
82/// ArmnnSubgraph class where parsing the nodes to ArmNN format and creating the ArmNN Graph
83class ArmnnSubgraph
84{
85public:
86 static ArmnnSubgraph* Create(TfLiteContext* tfLiteContext,
87 const TfLiteDelegateParams* parameters,
88 const Delegate* delegate);
89
90 TfLiteStatus Prepare(TfLiteContext* tfLiteContext);
91
Sadik Armagan62483be2020-10-23 17:14:43 +010092 TfLiteStatus Invoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode);
Sadik Armagan3c24f432020-10-19 17:35:30 +010093
Sadik Armagan62483be2020-10-23 17:14:43 +010094 static TfLiteStatus VisitNode(DelegateData& delegateData,
Sadik Armagan3c24f432020-10-19 17:35:30 +010095 TfLiteContext* tfLiteContext,
96 TfLiteRegistration* tfLiteRegistration,
97 TfLiteNode* tfLiteNode,
98 int nodeIndex);
99
100private:
Sadik Armagan62483be2020-10-23 17:14:43 +0100101 ArmnnSubgraph(armnn::NetworkId networkId,
102 armnn::IRuntime* runtime,
103 std::vector<armnn::BindingPointInfo>& inputBindings,
104 std::vector<armnn::BindingPointInfo>& outputBindings)
105 : m_NetworkId(networkId), m_Runtime(runtime), m_InputBindings(inputBindings), m_OutputBindings(outputBindings)
Sadik Armagan3c24f432020-10-19 17:35:30 +0100106 {}
107
Sadik Armagan62483be2020-10-23 17:14:43 +0100108 static TfLiteStatus AddInputLayer(DelegateData& delegateData,
109 TfLiteContext* tfLiteContext,
110 const TfLiteIntArray* inputs,
111 std::vector<armnn::BindingPointInfo>& inputBindings);
112
113 static TfLiteStatus AddOutputLayer(DelegateData& delegateData,
114 TfLiteContext* tfLiteContext,
115 const TfLiteIntArray* outputs,
116 std::vector<armnn::BindingPointInfo>& outputBindings);
117
118
Sadik Armagan3c24f432020-10-19 17:35:30 +0100119 /// The Network Id
120 armnn::NetworkId m_NetworkId;
121 /// ArmNN Rumtime
122 armnn::IRuntime* m_Runtime;
Sadik Armagan62483be2020-10-23 17:14:43 +0100123
124 // Binding information for inputs and outputs
125 std::vector<armnn::BindingPointInfo> m_InputBindings;
126 std::vector<armnn::BindingPointInfo> m_OutputBindings;
127
Sadik Armagan3c24f432020-10-19 17:35:30 +0100128};
129
Sadik Armagan3c24f432020-10-19 17:35:30 +0100130} // armnnDelegate namespace
131
Sadik Armagan3c24f432020-10-19 17:35:30 +0100132