//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <OpaqueDelegateUtils.hpp>

namespace armnnOpaqueDelegate
{
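    // Helper that asks the configured backends (via IsBroadcastToSupported) whether
    // BROADCAST_TO can be executed with the given input/output infos and descriptor.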
    TfLiteStatus ValidateBroadcastToOperator(DelegateData& delegateData,
                                             TfLiteOpaqueContext* tfLiteContext,
                                             const armnn::TensorInfo& inputInfo,
                                             const armnn::TensorInfo& outputInfo,
                                             const armnn::BroadcastToDescriptor& descriptor)
    {
        bool isSupported = false;
        FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("BROADCAST_TO",
                                          tfLiteContext,
                                          IsBroadcastToSupported,
                                          delegateData.m_Backends,
                                          isSupported,
                                          armnn::BackendId(),
                                          inputInfo,
                                          outputInfo,
                                          descriptor);
        return isSupported ? kTfLiteOk : kTfLiteError;
    }

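    // Parses a TFLite BROADCAST_TO node: checks the input/output tensors, builds the
    // BroadcastToDescriptor from the shape tensor, and either validates backend support
    // (when no network is being built) or adds a BroadcastTo layer to the ArmNN network.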
    TfLiteStatus VisitBroadcastToOperator(DelegateData& delegateData,
                                          TfLiteOpaqueContext* tfLiteContext,
                                          TfLiteOpaqueNode* tfLiteNode,
                                          int nodeIndex,
                                          int32_t broadcastToOperatorCode)
    {
        TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
        TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

        // Gather input tensors
        auto numInputs = TfLiteOpaqueNodeNumberOfInputs(tfLiteNode);
        const int* inputTensors;
        if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
        {
            TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                    tfLiteContext,
                    "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
                    nodeIndex);
            return kTfLiteError;
        }

        // Gather output tensors
        int numOutputs = 0;
        const int* outputTensors;
        if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
        {
            TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                    tfLiteContext,
                    "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
                    nodeIndex);
            return kTfLiteError;
        }

        // The first input holds the data to be broadcast
        const TfLiteOpaqueTensor* tfLiteInputTensor =
                TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
        if (IsDynamicTensor(tfLiteInputTensor))
        {
            TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                    tfLiteContext,
                    "TfLiteArmnnOpaqueDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
                    broadcastToOperatorCode, nodeIndex);
            return kTfLiteError;
        }

        // The second input is the shape tensor that defines the broadcast target shape
        const TfLiteOpaqueTensor* tfLiteShapeTensor =
                TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
        if (IsDynamicTensor(tfLiteShapeTensor))
        {
            TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                    tfLiteContext,
                    "TfLiteArmnnOpaqueDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
                    broadcastToOperatorCode, nodeIndex);
            return kTfLiteError;
        }

        // The output tensor
        const TfLiteOpaqueTensor* tfLiteOutputTensor =
                TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
        if (IsDynamicTensor(tfLiteOutputTensor))
        {
            TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                    tfLiteContext,
                    "TfLiteArmnnOpaqueDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
                    broadcastToOperatorCode, nodeIndex);
            return kTfLiteError;
        }

        const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
        const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);

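        // The shape tensor holds the target shape as a 1-D array of int32 values;
        // its length gives the number of dimensions of the broadcast output.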
        auto* shapeData = static_cast<int32_t*>(TfLiteOpaqueTensorData(tfLiteShapeTensor));
        int32_t shapeTensorNum = TfLiteOpaqueTensorDim(tfLiteShapeTensor, 0);

        armnn::BroadcastToDescriptor broadcastToDescriptor;
        broadcastToDescriptor.m_BroadcastToShape = armnn::TensorShape(shapeTensorNum, shapeData);

        // No network pointer indicates that only support for this operator should be checked
        if (!delegateData.m_Network)
        {
            return ValidateBroadcastToOperator(delegateData,
                                               tfLiteContext,
                                               inputTensorInfo,
                                               outputTensorInfo,
                                               broadcastToDescriptor);
        }

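        // Add a named BroadcastTo layer to the network being constructed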
        std::string layerName("BroadcastTo");
        armnn::IConnectableLayer* layer = delegateData.m_Network->AddBroadcastToLayer(broadcastToDescriptor,
                                                                                      layerName.c_str());

        if (layer == nullptr)
        {
            return kTfLiteError;
        }

        layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);

        if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
        {
            return kTfLiteError;
        }

        return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
    }

} // namespace armnnOpaqueDelegate
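
// A minimal sketch of how this visitor is typically reached; the surrounding switch
// lives in the delegate's node-visiting code, not in this header, and is shown only
// for illustration (assuming dispatch on TFLite's kTfLiteBuiltinBroadcastTo code):
//
//     case kTfLiteBuiltinBroadcastTo:
//         return VisitBroadcastToOperator(delegateData, tfLiteContext, tfLiteNode,
//                                         nodeIndex, kTfLiteBuiltinBroadcastTo);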