blob: 6f18185d7bcaac2289ae1468f478a7b755d71345 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
Sadik Armagan62483be2020-10-23 17:14:43 +01006#ifndef ARMNN_TFLITE_DELEGATE
7#define ARMNN_TFLITE_DELEGATE
Sadik Armagan3c24f432020-10-19 17:35:30 +01008
9#include "DelegateOptions.hpp"
10
11#include <tensorflow/lite/builtin_ops.h>
12#include <tensorflow/lite/c/builtin_op_data.h>
13#include <tensorflow/lite/c/common.h>
14#include <tensorflow/lite/minimal_logging.h>
15
16namespace armnnDelegate
17{
18
Sadik Armagan62483be2020-10-23 17:14:43 +010019struct DelegateData
20{
21 DelegateData(const std::vector<armnn::BackendId>& backends)
22 : m_Backends(backends)
23 , m_Network(nullptr, nullptr)
24 {}
Sadik Armagan3c24f432020-10-19 17:35:30 +010025
Sadik Armagan62483be2020-10-23 17:14:43 +010026 const std::vector<armnn::BackendId> m_Backends;
27 armnn::INetworkPtr m_Network;
28 std::vector<armnn::IOutputSlot*> m_OutputSlotForNode;
29};
30
31// Forward decleration for functions initializing the ArmNN Delegate
32DelegateOptions TfLiteArmnnDelegateOptionsDefault();
33
34TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options);
35
36void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate);
37
38TfLiteStatus DoPrepare(TfLiteContext* context, TfLiteDelegate* delegate);
39
40/// ArmNN Delegate
Sadik Armagan3c24f432020-10-19 17:35:30 +010041class Delegate
42{
43 friend class ArmnnSubgraph;
44public:
45 explicit Delegate(armnnDelegate::DelegateOptions options);
46
Sadik Armagan62483be2020-10-23 17:14:43 +010047 TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteContext* context);
Sadik Armagan3c24f432020-10-19 17:35:30 +010048
49 TfLiteDelegate* GetDelegate();
50
51private:
52 TfLiteDelegate m_Delegate = {
53 reinterpret_cast<void*>(this), // .data_
Sadik Armagan62483be2020-10-23 17:14:43 +010054 DoPrepare, // .Prepare
Sadik Armagan3c24f432020-10-19 17:35:30 +010055 nullptr, // .CopyFromBufferHandle
56 nullptr, // .CopyToBufferHandle
57 nullptr, // .FreeBufferHandle
58 kTfLiteDelegateFlagsNone, // .flags
59 };
60
Sadik Armagan62483be2020-10-23 17:14:43 +010061 /// ArmNN Runtime pointer
Sadik Armagan3c24f432020-10-19 17:35:30 +010062 armnn::IRuntimePtr m_Runtime;
Sadik Armagan62483be2020-10-23 17:14:43 +010063 /// ArmNN Delegate Options
Sadik Armagan3c24f432020-10-19 17:35:30 +010064 armnnDelegate::DelegateOptions m_Options;
65};
66
67/// ArmnnSubgraph class where parsing the nodes to ArmNN format and creating the ArmNN Graph
68class ArmnnSubgraph
69{
70public:
71 static ArmnnSubgraph* Create(TfLiteContext* tfLiteContext,
72 const TfLiteDelegateParams* parameters,
73 const Delegate* delegate);
74
75 TfLiteStatus Prepare(TfLiteContext* tfLiteContext);
76
Sadik Armagan62483be2020-10-23 17:14:43 +010077 TfLiteStatus Invoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode);
Sadik Armagan3c24f432020-10-19 17:35:30 +010078
Sadik Armagan62483be2020-10-23 17:14:43 +010079 static TfLiteStatus VisitNode(DelegateData& delegateData,
Sadik Armagan3c24f432020-10-19 17:35:30 +010080 TfLiteContext* tfLiteContext,
81 TfLiteRegistration* tfLiteRegistration,
82 TfLiteNode* tfLiteNode,
83 int nodeIndex);
84
85private:
Sadik Armagan62483be2020-10-23 17:14:43 +010086 ArmnnSubgraph(armnn::NetworkId networkId,
87 armnn::IRuntime* runtime,
88 std::vector<armnn::BindingPointInfo>& inputBindings,
89 std::vector<armnn::BindingPointInfo>& outputBindings)
90 : m_NetworkId(networkId), m_Runtime(runtime), m_InputBindings(inputBindings), m_OutputBindings(outputBindings)
Sadik Armagan3c24f432020-10-19 17:35:30 +010091 {}
92
Sadik Armagan62483be2020-10-23 17:14:43 +010093 static TfLiteStatus AddInputLayer(DelegateData& delegateData,
94 TfLiteContext* tfLiteContext,
95 const TfLiteIntArray* inputs,
96 std::vector<armnn::BindingPointInfo>& inputBindings);
97
98 static TfLiteStatus AddOutputLayer(DelegateData& delegateData,
99 TfLiteContext* tfLiteContext,
100 const TfLiteIntArray* outputs,
101 std::vector<armnn::BindingPointInfo>& outputBindings);
102
103
Sadik Armagan3c24f432020-10-19 17:35:30 +0100104 /// The Network Id
105 armnn::NetworkId m_NetworkId;
106 /// ArmNN Rumtime
107 armnn::IRuntime* m_Runtime;
Sadik Armagan62483be2020-10-23 17:14:43 +0100108
109 // Binding information for inputs and outputs
110 std::vector<armnn::BindingPointInfo> m_InputBindings;
111 std::vector<armnn::BindingPointInfo> m_OutputBindings;
112
Sadik Armagan3c24f432020-10-19 17:35:30 +0100113};
114
Sadik Armagan3c24f432020-10-19 17:35:30 +0100115} // armnnDelegate namespace
116
Sadik Armagan62483be2020-10-23 17:14:43 +0100117#endif // ARMNN_TFLITE_DELEGATE
Sadik Armagan3c24f432020-10-19 17:35:30 +0100118