blob: 92aed799820034a86ee75c243d3bf61b154dfe5e [file] [log] [blame]
Idriss Chaouchcbf79292023-09-08 11:18:16 +01001//
2// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include <armnn/utility/IgnoreUnused.hpp>

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>
#include <tensorflow/lite/kernels/internal/tensor_ctypes.h>
#include <tensorflow/lite/schema/schema_generated.h>
#include <armnn_delegate.hpp>

#include <cstdint>
#include <limits>
#include <vector>
17
18namespace armnnDelegate
19{
20 TfLiteStatus ValidateBroadcastToOperator(DelegateData& delegateData,
21 TfLiteContext* tfLiteContext,
22 const armnn::TensorInfo& inputInfo,
23 const armnn::TensorInfo& outputInfo,
24 const armnn::BroadcastToDescriptor& descriptor)
25 {
26 bool isSupported = false;
27 FORWARD_LAYER_SUPPORT_FUNC("BROADCAST_TO",
28 tfLiteContext,
29 IsBroadcastToSupported,
30 delegateData.m_Backends,
31 isSupported,
32 armnn::BackendId(),
33 inputInfo,
34 outputInfo,
35 descriptor);
36 return isSupported ? kTfLiteOk : kTfLiteError;
37 }
38
39 TfLiteStatus VisitBroadcastToOperator(DelegateData& delegateData,
40 TfLiteContext* tfLiteContext,
41 TfLiteNode* tfLiteNode,
42 int nodeIndex,
43 int32_t broadcastToOperatorCode)
44 {
45 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
46 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
47
48 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
49
50 // The input contains the data that should be broadcasted
51 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
52 if (IsDynamicTensor(tfLiteInputTensor))
53 {
54 TF_LITE_MAYBE_KERNEL_LOG(
55 tfLiteContext,
56 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
57 broadcastToOperatorCode, nodeIndex);
58 return kTfLiteError;
59 }
60
61 // The shape tensor contains the new shape to be applied on the input
62 const TfLiteTensor& tfLiteShapeTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
63 if (IsDynamicTensor(tfLiteShapeTensor))
64 {
65 TF_LITE_MAYBE_KERNEL_LOG(
66 tfLiteContext,
67 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
68 broadcastToOperatorCode, nodeIndex);
69 return kTfLiteError;
70 }
71
72 // The output tensor
73 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
74 if (IsDynamicTensor(tfLiteOutputTensor))
75 {
76 TF_LITE_MAYBE_KERNEL_LOG(
77 tfLiteContext,
78 "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
79 broadcastToOperatorCode, nodeIndex);
80 return kTfLiteError;
81 }
82
83 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
84 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
85
86 auto* shapeData = tflite::GetTensorData<int32_t>(&tfLiteShapeTensor);
87 auto shapeTensorNum = tfLiteShapeTensor.dims->data[0];
88
89 armnn::BroadcastToDescriptor broadcastToDescriptor;
90 broadcastToDescriptor.m_BroadcastToShape = armnn::TensorShape(shapeTensorNum,
91 shapeData);
92
93 // No network pointer indicates that only support for this operator should be checked
94 if (!delegateData.m_Network)
95 {
96 return ValidateBroadcastToOperator(delegateData,
97 tfLiteContext,
98 inputTensorInfo,
99 outputTensorInfo,
100 broadcastToDescriptor);
101 }
102
103 auto layerName = GetLayerName(armnn::LayerType::BroadcastTo, nodeIndex);
104 armnn::IConnectableLayer* layer = delegateData.m_Network->AddBroadcastToLayer(broadcastToDescriptor,
105 layerName.c_str());
106
107 if (layer == nullptr)
108 {
109 return kTfLiteError;
110 }
111
112 layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
113
114 if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
115 {
116 return kTfLiteError;
117 }
118
119 return Connect(layer, tfLiteNode, delegateData);
120 }
121
122} // namespace armnnDelegate