//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <OpaqueDelegateUtils.hpp>

10namespace armnnOpaqueDelegate
11{
13 TfLiteStatus VisitFillOperator(DelegateData& delegateData,
14 TfLiteOpaqueContext* tfLiteContext,
15 TfLiteOpaqueNode* tfLiteNode,
16 int nodeIndex,
17 int32_t tfLiteFillOperatorCode)
18 {
19 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
20
21 switch(tfLiteFillOperatorCode)
22 {
23 case kTfLiteBuiltinFill:
24 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
25 break;
26 default:
27 return kTfLiteError;
28 }
29
30 // Inputs
31 int numInputs = 0;
32 const int* inputTensors;
33 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
34 {
35 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
36 tfLiteContext,
37 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
38 nodeIndex);
39 return kTfLiteError;
40 }
41
42 const TfLiteOpaqueTensor* tfLiteInputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext,
43 inputTensors[0]);
44 if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteFillOperatorCode, nodeIndex))
45 {
46 return kTfLiteError;
47 }
48
49 const TfLiteOpaqueTensor* tfLiteFillTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext,
50 inputTensors[1]);
51 if (!IsValid(tfLiteContext, tfLiteFillTensor, tfLiteFillOperatorCode, nodeIndex))
52 {
53 return kTfLiteError;
54 }
55
56 int numOutputs = 0;
57 const int* outputTensors;
58 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
59 {
60 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
61 tfLiteContext,
62 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
63 nodeIndex);
64 return kTfLiteError;
65 }
66
67 const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext,
68 outputTensors[0]);
69 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteFillOperatorCode, nodeIndex))
70 {
71 return kTfLiteError;
72 }
73
74 armnn::TensorInfo inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
75 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
76
77 armnn::FillDescriptor descriptor;
78 switch (TfLiteOpaqueTensorType(tfLiteFillTensor))
79 {
80 case kTfLiteFloat32:
81 descriptor.m_Value = *static_cast<float*>(TfLiteOpaqueTensorData(tfLiteFillTensor));
82 break;
83 case kTfLiteInt32:
84 descriptor.m_Value = *static_cast<int32_t*>(TfLiteOpaqueTensorData(tfLiteFillTensor));
85 break;
86 default:
87 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
88 tfLiteContext,
89 "TfLiteArmnnOpaqueDelegate: FILL value data type is not supported in operator #%d node #%d: ",
90 tfLiteFillOperatorCode, nodeIndex);
91 return kTfLiteError;
92 }
93
94 bool isSupported = false;
95 armnn::BackendId setBackend;
96 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
97 {
98 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("FILL",
99 tfLiteContext,
100 IsFillSupported,
101 delegateData.m_Backends,
102 isSupported,
103 setBackend,
104 inputTensorInfo,
105 outInfo,
106 descriptor);
107 };
108
109 if (!delegateData.m_Network)
110 {
111 validateFunc(outputTensorInfo, isSupported);
112 return isSupported ? kTfLiteOk : kTfLiteError;
113 }
114
Mike Kellya2806502023-08-03 10:42:11 +0100115 auto layerName = GetName(armnn::LayerType::Fill, nodeIndex);
116 armnn::IConnectableLayer* layer = delegateData.m_Network->AddFillLayer(descriptor, layerName.c_str());
Ryan OShea59f8f652023-05-11 20:37:53 +0100117 layer->SetBackendId(setBackend);
118 ARMNN_ASSERT(layer != nullptr);
119
120 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
121 outputSlot.SetTensorInfo(outputTensorInfo);
122
123 auto inputsTensorsProcess = ProcessInputs(layer,
124 delegateData,
125 tfLiteContext,
Mike Kellya2806502023-08-03 10:42:11 +0100126 tfLiteNode,
127 nodeIndex);
Ryan OShea59f8f652023-05-11 20:37:53 +0100128 if (inputsTensorsProcess == kTfLiteError)
129 {
130 return inputsTensorsProcess;
131 }
132
133 return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
134 }

} // namespace armnnOpaqueDelegate