blob: 5a05232e3bc3e60f2f12c1e1eecda796ffc8e246 [file] [log] [blame]
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
Teresa Charlinecebb0f2023-04-27 21:37:56 +01005
6#pragma once
7
8#include <OpaqueDelegateUtils.hpp>
9
10namespace armnnOpaqueDelegate
11{
12
13TfLiteStatus VisitPackOperator(DelegateData& delegateData,
14 TfLiteOpaqueContext* tfLiteContext,
15 TfLiteOpaqueNode* tfLiteNode,
16 int nodeIndex,
17 int32_t tfLitePackOperatorCode)
18{
19 // Check Inputs
20 auto numInputs = TfLiteOpaqueNodeNumberOfInputs(tfLiteNode);
21 if (numInputs < 1)
22 {
23 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
24 tfLiteContext,
25 "TfLiteArmnnOpaqueDelegate: Must have at least one input in (%d != %d) in node #%d",
26 1,
27 numInputs,
28 nodeIndex);
29 return kTfLiteError;
30 }
31
32 // Gather input indices and use to get input tensors.
33 const int* inputTensors;
34 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
35 {
36 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
37 tfLiteContext,
38 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
39 nodeIndex);
40 return kTfLiteError;
41 }
42
43 // Validate all inputs and get TensorInfo
44 std::vector<armnn::TensorInfo> inputTensorInfos;
45 for (int i = 0; i < numInputs; ++i)
46 {
47 const TfLiteOpaqueTensor* inputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[i]);
48 if (!IsValid(tfLiteContext, inputTensor, tfLitePackOperatorCode, nodeIndex))
49 {
50 return kTfLiteError;
51 }
52
53 armnn::TensorInfo inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(inputTensor);
54 inputTensorInfos.emplace_back(inputTensorInfo);
55 }
56
57 // Convert inputTensorInfos to const armnn::TensorInfo* type for FORWARD_LAYER_OPAQUE_SUPPORT_FUNC.
58 std::vector<const armnn::TensorInfo*> inputConstTensorInfos;
59 std::transform(inputTensorInfos.begin(),
60 inputTensorInfos.end(),
61 std::back_inserter(inputConstTensorInfos),
62 [](armnn::TensorInfo& t)->const armnn::TensorInfo*{ return &t; });
63
64 // Check outputs
65 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
66
67 // Gather output indices and use to get output tensor.
68 const int* outputTensors;
69 int numOutputs;
70 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
71 {
72 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
73 tfLiteContext,
74 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
75 nodeIndex);
76 return kTfLiteError;
77 }
78
79 // Validate the output and get TensorInfo
80 const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
81 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLitePackOperatorCode, nodeIndex))
82 {
83 return kTfLiteError;
84 }
85 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
86
87 armnn::StackDescriptor desc;
88 desc.m_NumInputs = static_cast<uint32_t>(numInputs);
89
90 // Get axis from TfLite parameters
91 auto* tfLiteNodeParameters = reinterpret_cast<TfLitePackParams*>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode));
92 auto axis = tfLiteNodeParameters->axis;
93 desc.m_Axis = NonNegative(axis, nodeIndex);
94
95 // Use the tensor shape of the first input as the "correct" input shape in the descriptor
96 desc.m_InputShape = inputTensorInfos[0].GetShape();
97
98 // Check if supported
99 bool isSupported = false;
100 armnn::BackendId setBackend;
101 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
102 {
103 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("STACK",
104 tfLiteContext,
105 IsStackSupported,
106 delegateData.m_Backends,
107 isSupported,
108 setBackend,
109 inputConstTensorInfos,
110 outputTensorInfo,
111 desc);
112 };
113
114 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
115 // support for the operator
116 // If supported, VisitPackOperator will be called again to add the layer to the network as seen below
117 if (!delegateData.m_Network)
118 {
119 validateFunc(outputTensorInfo, isSupported);
120 return isSupported ? kTfLiteOk : kTfLiteError;
121 }
122
123 // The TfLite Pack operator is equivalent to the ArmNN Stack operator
Mike Kellya2806502023-08-03 10:42:11 +0100124 auto layerName = GetName(armnn::LayerType::Stack, nodeIndex);
125 armnn::IConnectableLayer* layer = delegateData.m_Network->AddStackLayer(desc, layerName.c_str());
Teresa Charlinecebb0f2023-04-27 21:37:56 +0100126 layer->SetBackendId(setBackend);
127 ARMNN_ASSERT(layer != nullptr);
128
129 // Connect the Constant Inputs
130 auto inputsTensorsProcess = ProcessInputs(layer,
131 delegateData,
132 tfLiteContext,
Mike Kellya2806502023-08-03 10:42:11 +0100133 tfLiteNode,
134 nodeIndex);
Teresa Charlinecebb0f2023-04-27 21:37:56 +0100135 if (inputsTensorsProcess == kTfLiteError)
136 {
137 return inputsTensorsProcess;
138 }
139
140 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
141 outputSlot.SetTensorInfo(outputTensorInfo);
142
143 // Connect
144 return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
145}
146
147} // namespace armnnOpaqueDelegate