Sadik Armagan | 3c24f43 | 2020-10-19 17:35:30 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2020 Arm Ltd and Contributors. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
#include <armnn_delegate.hpp>
#include <algorithm>
#include <string>
| 8 | |
| 9 | namespace armnnDelegate |
| 10 | { |
| 11 | |
| 12 | Delegate::Delegate(armnnDelegate::DelegateOptions options) |
| 13 | : m_Runtime(nullptr, nullptr), |
| 14 | m_Options(std::move(options)) |
| 15 | { |
| 16 | // Create ArmNN Runtime |
| 17 | armnn::IRuntime::CreationOptions runtimeOptions; |
| 18 | m_Runtime = armnn::IRuntime::Create(runtimeOptions); |
| 19 | |
| 20 | std::vector<armnn::BackendId> backends; |
| 21 | |
| 22 | if (m_Runtime) |
| 23 | { |
| 24 | const armnn::BackendIdSet supportedDevices = m_Runtime->GetDeviceSpec().GetSupportedBackends(); |
| 25 | for (auto& backend : m_Options.GetBackends()) |
| 26 | { |
| 27 | if (std::find(supportedDevices.cbegin(), supportedDevices.cend(), backend) == supportedDevices.cend()) |
| 28 | { |
| 29 | TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, |
| 30 | "TfLiteArmnnDelegate: Requested unknown backend %s", backend.Get().c_str()); |
| 31 | } |
| 32 | else |
| 33 | { |
| 34 | backends.push_back(backend); |
| 35 | } |
| 36 | } |
| 37 | } |
| 38 | |
| 39 | if (backends.empty()) |
| 40 | { |
| 41 | // No known backend specified |
| 42 | throw armnn::InvalidArgumentException("TfLiteArmnnDelegate: No known backend specified."); |
| 43 | } |
| 44 | m_Options.SetBackends(backends); |
| 45 | |
| 46 | TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "TfLiteArmnnDelegate: Created TfLite ArmNN delegate."); |
| 47 | } |
| 48 | |
| 49 | TfLiteIntArray* Delegate::CollectOperatorsToDelegate(TfLiteContext* tfLiteContext) |
| 50 | { |
| 51 | TfLiteIntArray* executionPlan = nullptr; |
| 52 | if (tfLiteContext->GetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk) |
| 53 | { |
| 54 | TF_LITE_KERNEL_LOG(tfLiteContext, "TfLiteArmnnDelegate: Unable to get graph execution plan."); |
| 55 | return nullptr; |
| 56 | } |
| 57 | |
| 58 | // Null INetworkPtr |
| 59 | armnn::INetworkPtr nullNetworkPtr(nullptr, nullptr); |
| 60 | |
| 61 | TfLiteIntArray* nodesToDelegate = TfLiteIntArrayCreate(executionPlan->size); |
| 62 | nodesToDelegate->size = 0; |
| 63 | for (int i = 0; i < executionPlan->size; ++i) |
| 64 | { |
| 65 | const int nodeIndex = executionPlan->data[i]; |
| 66 | |
| 67 | // If TfLite nodes can be delegated to ArmNN |
| 68 | TfLiteNode* tfLiteNode = nullptr; |
| 69 | TfLiteRegistration* tfLiteRegistration = nullptr; |
| 70 | if (tfLiteContext->GetNodeAndRegistration( |
| 71 | tfLiteContext, nodeIndex, &tfLiteNode, &tfLiteRegistration) != kTfLiteOk) |
| 72 | { |
| 73 | TF_LITE_KERNEL_LOG(tfLiteContext, |
| 74 | "TfLiteArmnnDelegate: Unable to get node and registration for node %d.", |
| 75 | nodeIndex); |
| 76 | continue; |
| 77 | } |
| 78 | |
| 79 | if (ArmnnSubgraph::VisitNode( |
| 80 | nullNetworkPtr, tfLiteContext, tfLiteRegistration, tfLiteNode, nodeIndex) != kTfLiteOk) |
| 81 | { |
| 82 | // node is not supported by ArmNN |
| 83 | continue; |
| 84 | } |
| 85 | |
| 86 | nodesToDelegate->data[nodesToDelegate->size++] = nodeIndex; |
| 87 | } |
| 88 | |
| 89 | std::sort(&nodesToDelegate->data[0], |
| 90 | &nodesToDelegate->data[nodesToDelegate->size]); |
| 91 | |
| 92 | return nodesToDelegate; |
| 93 | } |
| 94 | |
| 95 | TfLiteDelegate* Delegate::GetDelegate() |
| 96 | { |
| 97 | return &m_Delegate; |
| 98 | } |
| 99 | |
| 100 | ArmnnSubgraph* ArmnnSubgraph::Create(TfLiteContext* tfLiteContext, |
| 101 | const TfLiteDelegateParams* parameters, |
| 102 | const Delegate* delegate) |
| 103 | { |
| 104 | TfLiteIntArray* executionPlan; |
| 105 | if (tfLiteContext->GetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk) |
| 106 | { |
| 107 | return nullptr; |
| 108 | } |
| 109 | |
| 110 | // Construct ArmNN network |
| 111 | using NetworkOptions = std::vector<armnn::BackendOptions>; |
| 112 | armnn::NetworkOptions networkOptions = {}; |
| 113 | armnn::NetworkId networkId; |
| 114 | armnn::INetworkPtr network = armnn::INetwork::Create(networkOptions); |
| 115 | |
| 116 | // Parse TfLite delegate nodes to ArmNN nodes |
| 117 | for (int i = 0; i < parameters->nodes_to_replace->size; ++i) |
| 118 | { |
| 119 | const int nodeIndex = parameters->nodes_to_replace->data[i]; |
| 120 | |
| 121 | TfLiteNode* tfLiteNode = nullptr; |
| 122 | TfLiteRegistration* tfLiteRegistration = nullptr; |
| 123 | if (tfLiteContext->GetNodeAndRegistration( |
| 124 | tfLiteContext, nodeIndex, &tfLiteNode, &tfLiteRegistration) != kTfLiteOk) |
| 125 | { |
| 126 | throw armnn::Exception("TfLiteArmnnDelegate: Unable to get node registration: " + nodeIndex); |
| 127 | } |
| 128 | |
| 129 | if (VisitNode(network, tfLiteContext, tfLiteRegistration, tfLiteNode, nodeIndex) != kTfLiteOk) |
| 130 | { |
| 131 | throw armnn::Exception("TfLiteArmnnDelegate: Unable to parse node: " + nodeIndex); |
| 132 | } |
| 133 | } |
| 134 | |
| 135 | // Optimise Arm NN network |
| 136 | armnn::IOptimizedNetworkPtr optNet = |
| 137 | armnn::Optimize(*network, delegate->m_Options.GetBackends(), delegate->m_Runtime->GetDeviceSpec()); |
| 138 | if (!optNet) |
| 139 | { |
| 140 | // Optimize Failed |
| 141 | throw armnn::Exception("TfLiteArmnnDelegate: Unable to optimize the network!"); |
| 142 | } |
| 143 | // Load graph into runtime |
| 144 | delegate->m_Runtime->LoadNetwork(networkId, std::move(optNet)); |
| 145 | |
| 146 | // Create a new SubGraph with networkId and runtime |
| 147 | return new ArmnnSubgraph(networkId, delegate->m_Runtime.get()); |
| 148 | } |
| 149 | |
| 150 | TfLiteStatus ArmnnSubgraph::Prepare(TfLiteContext* tfLiteContext) |
| 151 | { |
| 152 | return kTfLiteOk; |
| 153 | } |
| 154 | |
| 155 | TfLiteStatus ArmnnSubgraph::Invoke(TfLiteContext* tfLiteContext) |
| 156 | { |
| 157 | /// Get the Input Tensors and OutputTensors from the context |
| 158 | /// Execute the network |
| 159 | //m_Runtime->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors); |
| 160 | |
| 161 | return kTfLiteOk; |
| 162 | } |
| 163 | |
| 164 | TfLiteStatus ArmnnSubgraph::VisitNode(armnn::INetworkPtr& network, |
| 165 | TfLiteContext* tfLiteContext, |
| 166 | TfLiteRegistration* tfLiteRegistration, |
| 167 | TfLiteNode* tfLiteNode, |
| 168 | int nodeIndex) |
| 169 | { |
| 170 | /* |
| 171 | * Take the node and check what operator it is and VisitXXXLayer() |
| 172 | * In the VisitXXXLayer() function parse TfLite node to Arm NN Layer and add it to tho network graph |
| 173 | *switch (tfLiteRegistration->builtin_code) |
| 174 | * { |
| 175 | * case kTfLiteBuiltinAbs: |
| 176 | * return VisitAbsLayer(...); |
| 177 | * ... |
| 178 | * default: |
| 179 | * return kTfLiteError; |
| 180 | * } |
| 181 | */ |
| 182 | return kTfLiteError; |
| 183 | } |
| 184 | |
| 185 | } // armnnDelegate namespace |