blob: adf264aabbcc24749fa06a3f874a45a8bfb6f7f9 [file] [log] [blame]
Sadik Armagan3c24f432020-10-19 17:35:30 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
Sadik Armagan5d03e312020-11-17 16:43:56 +00006#pragma once
Sadik Armagan3c24f432020-10-19 17:35:30 +01007
#include "DelegateOptions.hpp"

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>

#include <utility>
#include <vector>
14
15namespace armnnDelegate
16{
17
Sadik Armagan62483be2020-10-23 17:14:43 +010018struct DelegateData
19{
20 DelegateData(const std::vector<armnn::BackendId>& backends)
21 : m_Backends(backends)
22 , m_Network(nullptr, nullptr)
23 {}
Sadik Armagan3c24f432020-10-19 17:35:30 +010024
Sadik Armagan62483be2020-10-23 17:14:43 +010025 const std::vector<armnn::BackendId> m_Backends;
26 armnn::INetworkPtr m_Network;
27 std::vector<armnn::IOutputSlot*> m_OutputSlotForNode;
28};
29
30// Forward decleration for functions initializing the ArmNN Delegate
31DelegateOptions TfLiteArmnnDelegateOptionsDefault();
32
33TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options);
34
35void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate);
36
37TfLiteStatus DoPrepare(TfLiteContext* context, TfLiteDelegate* delegate);
38
39/// ArmNN Delegate
Sadik Armagan3c24f432020-10-19 17:35:30 +010040class Delegate
41{
42 friend class ArmnnSubgraph;
43public:
44 explicit Delegate(armnnDelegate::DelegateOptions options);
45
Sadik Armagan62483be2020-10-23 17:14:43 +010046 TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteContext* context);
Sadik Armagan3c24f432020-10-19 17:35:30 +010047
48 TfLiteDelegate* GetDelegate();
49
50private:
51 TfLiteDelegate m_Delegate = {
52 reinterpret_cast<void*>(this), // .data_
Sadik Armagan62483be2020-10-23 17:14:43 +010053 DoPrepare, // .Prepare
Sadik Armagan3c24f432020-10-19 17:35:30 +010054 nullptr, // .CopyFromBufferHandle
55 nullptr, // .CopyToBufferHandle
56 nullptr, // .FreeBufferHandle
57 kTfLiteDelegateFlagsNone, // .flags
58 };
59
Sadik Armagan62483be2020-10-23 17:14:43 +010060 /// ArmNN Runtime pointer
Sadik Armagan3c24f432020-10-19 17:35:30 +010061 armnn::IRuntimePtr m_Runtime;
Sadik Armagan62483be2020-10-23 17:14:43 +010062 /// ArmNN Delegate Options
Sadik Armagan3c24f432020-10-19 17:35:30 +010063 armnnDelegate::DelegateOptions m_Options;
64};
65
66/// ArmnnSubgraph class where parsing the nodes to ArmNN format and creating the ArmNN Graph
67class ArmnnSubgraph
68{
69public:
70 static ArmnnSubgraph* Create(TfLiteContext* tfLiteContext,
71 const TfLiteDelegateParams* parameters,
72 const Delegate* delegate);
73
74 TfLiteStatus Prepare(TfLiteContext* tfLiteContext);
75
Sadik Armagan62483be2020-10-23 17:14:43 +010076 TfLiteStatus Invoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode);
Sadik Armagan3c24f432020-10-19 17:35:30 +010077
Sadik Armagan62483be2020-10-23 17:14:43 +010078 static TfLiteStatus VisitNode(DelegateData& delegateData,
Sadik Armagan3c24f432020-10-19 17:35:30 +010079 TfLiteContext* tfLiteContext,
80 TfLiteRegistration* tfLiteRegistration,
81 TfLiteNode* tfLiteNode,
82 int nodeIndex);
83
84private:
Sadik Armagan62483be2020-10-23 17:14:43 +010085 ArmnnSubgraph(armnn::NetworkId networkId,
86 armnn::IRuntime* runtime,
87 std::vector<armnn::BindingPointInfo>& inputBindings,
88 std::vector<armnn::BindingPointInfo>& outputBindings)
89 : m_NetworkId(networkId), m_Runtime(runtime), m_InputBindings(inputBindings), m_OutputBindings(outputBindings)
Sadik Armagan3c24f432020-10-19 17:35:30 +010090 {}
91
Sadik Armagan62483be2020-10-23 17:14:43 +010092 static TfLiteStatus AddInputLayer(DelegateData& delegateData,
93 TfLiteContext* tfLiteContext,
94 const TfLiteIntArray* inputs,
95 std::vector<armnn::BindingPointInfo>& inputBindings);
96
97 static TfLiteStatus AddOutputLayer(DelegateData& delegateData,
98 TfLiteContext* tfLiteContext,
99 const TfLiteIntArray* outputs,
100 std::vector<armnn::BindingPointInfo>& outputBindings);
101
102
Sadik Armagan3c24f432020-10-19 17:35:30 +0100103 /// The Network Id
104 armnn::NetworkId m_NetworkId;
105 /// ArmNN Rumtime
106 armnn::IRuntime* m_Runtime;
Sadik Armagan62483be2020-10-23 17:14:43 +0100107
108 // Binding information for inputs and outputs
109 std::vector<armnn::BindingPointInfo> m_InputBindings;
110 std::vector<armnn::BindingPointInfo> m_OutputBindings;
111
Sadik Armagan3c24f432020-10-19 17:35:30 +0100112};
113
Sadik Armagan3c24f432020-10-19 17:35:30 +0100114} // armnnDelegate namespace
115
Sadik Armagan3c24f432020-10-19 17:35:30 +0100116