//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
John Mcloughlin0422cf22023-04-27 16:55:00 +01005
6#pragma once
7
8#include <OpaqueDelegateUtils.hpp>
9
10namespace armnnOpaqueDelegate
11{
12
13TfLiteStatus VisitShapeOperator(DelegateData& delegateData,
14 TfLiteOpaqueContext* tfLiteContext,
15 TfLiteOpaqueNode* tfLiteNode,
16 int nodeIndex,
17 int32_t operatorCode)
18{
19 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
20 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
21
22 // Gather input indices and use to get input tensor.
23 auto numInputs = TfLiteOpaqueNodeNumberOfInputs(tfLiteNode);
24 const int* inputTensors;
25 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
26 {
27 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
28 tfLiteContext,
29 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
30 nodeIndex);
31 return kTfLiteError;
32 }
33
Teresa Charlin86b03572023-04-28 13:19:12 +010034 const TfLiteOpaqueTensor* tfLiteInputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
John Mcloughlin0422cf22023-04-27 16:55:00 +010035 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
36 {
37 return kTfLiteError;
38 }
39
40 // Gather output indices and use to get output tensors.
41 int numOutputs = 0;
42 const int* outputTensors;
43 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
44 {
45 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
46 tfLiteContext,
47 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
48 nodeIndex);
49 return kTfLiteError;
50 }
51
Teresa Charlin86b03572023-04-28 13:19:12 +010052 const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
John Mcloughlin0422cf22023-04-27 16:55:00 +010053 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
54 {
55 return kTfLiteError;
56 }
57
58 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
59 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
60
61 auto* shapeParameters = reinterpret_cast<TfLiteShapeParams*>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode));
Mike Kellya2806502023-08-03 10:42:11 +010062 if (shapeParameters->out_type != kTfLiteInt32 && shapeParameters->out_type != kTfLiteInt64)
John Mcloughlin0422cf22023-04-27 16:55:00 +010063 {
64 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
65 tfLiteContext,
66 "TfLiteArmnnOpaqueDelegate: output_type data type is not supported in operator #%d node #%d: ",
67 operatorCode, nodeIndex);
68 return kTfLiteError;
69 }
70
71 bool isSupported = false;
72 armnn::BackendId setBackend;
73 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
74 {
75 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("SHAPE",
76 tfLiteContext,
77 IsShapeSupported,
78 delegateData.m_Backends,
79 isSupported,
80 setBackend,
81 inputTensorInfo,
82 outInfo);
83 };
84
85 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
86 // support for the operator
87 // If supported, VisitShapeOperator will be called again to add the layer to the network as seen further below
88 if (!delegateData.m_Network)
89 {
90 validateFunc(outputTensorInfo, isSupported);
91 return isSupported ? kTfLiteOk : kTfLiteError;
92 }
93
94 // Add a Shape layer
Mike Kellya2806502023-08-03 10:42:11 +010095 auto layerName = GetName(armnn::LayerType::Shape, nodeIndex);
96 armnn::IConnectableLayer* layer = delegateData.m_Network->AddShapeLayer(layerName.c_str());
John Mcloughlin0422cf22023-04-27 16:55:00 +010097 layer->SetBackendId(setBackend);
98 ARMNN_ASSERT(layer != nullptr);
99
100 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
101 outputSlot.SetTensorInfo(outputTensorInfo);
102
103 // try to connect the Constant Inputs if there are any
Mike Kellya2806502023-08-03 10:42:11 +0100104 if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
John Mcloughlin0422cf22023-04-27 16:55:00 +0100105 {
106 return kTfLiteError;
107 }
108
109 // Connect
110 return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
111}
112
113} // namespace armnnOpaqueDelegate