//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <OpaqueDelegateUtils.hpp>
10namespace armnnOpaqueDelegate
11{
12
13TfLiteStatus VisitShapeOperator(DelegateData& delegateData,
14 TfLiteOpaqueContext* tfLiteContext,
15 TfLiteOpaqueNode* tfLiteNode,
16 int nodeIndex,
17 int32_t operatorCode)
18{
19 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
20 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
21
22 // Gather input indices and use to get input tensor.
23 auto numInputs = TfLiteOpaqueNodeNumberOfInputs(tfLiteNode);
24 const int* inputTensors;
25 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
26 {
27 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
28 tfLiteContext,
29 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
30 nodeIndex);
31 return kTfLiteError;
32 }
33
34 const TfLiteOpaqueTensor* tfLiteInputTensor =
35 TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
36 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
37 {
38 return kTfLiteError;
39 }
40
41 // Gather output indices and use to get output tensors.
42 int numOutputs = 0;
43 const int* outputTensors;
44 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
45 {
46 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
47 tfLiteContext,
48 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
49 nodeIndex);
50 return kTfLiteError;
51 }
52
53 const TfLiteOpaqueTensor* tfLiteOutputTensor =
54 TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
55 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
56 {
57 return kTfLiteError;
58 }
59
60 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
61 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
62
63 auto* shapeParameters = reinterpret_cast<TfLiteShapeParams*>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode));
64 if ( shapeParameters->out_type != kTfLiteInt32 && shapeParameters->out_type != kTfLiteInt64 )
65 {
66 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
67 tfLiteContext,
68 "TfLiteArmnnOpaqueDelegate: output_type data type is not supported in operator #%d node #%d: ",
69 operatorCode, nodeIndex);
70 return kTfLiteError;
71 }
72
73 bool isSupported = false;
74 armnn::BackendId setBackend;
75 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
76 {
77 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("SHAPE",
78 tfLiteContext,
79 IsShapeSupported,
80 delegateData.m_Backends,
81 isSupported,
82 setBackend,
83 inputTensorInfo,
84 outInfo);
85 };
86
87 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
88 // support for the operator
89 // If supported, VisitShapeOperator will be called again to add the layer to the network as seen further below
90 if (!delegateData.m_Network)
91 {
92 validateFunc(outputTensorInfo, isSupported);
93 return isSupported ? kTfLiteOk : kTfLiteError;
94 }
95
96 // Add a Shape layer
97 armnn::IConnectableLayer* layer = delegateData.m_Network->AddShapeLayer();
98 layer->SetBackendId(setBackend);
99 ARMNN_ASSERT(layer != nullptr);
100
101 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
102 outputSlot.SetTensorInfo(outputTensorInfo);
103
104 // try to connect the Constant Inputs if there are any
105 if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
106 {
107 return kTfLiteError;
108 }
109
110 // Connect
111 return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
112}
113
114} // namespace armnnOpaqueDelegate