blob: 4064b6361ceb4aeb5eea390a754d5e2a7e1eda09 [file] [log] [blame]
Francis Murtaghc4fb0dd2023-03-16 17:01:56 +00001//
2// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
Matthew Sloyan48ec8132023-04-27 17:04:47 +01005
6#pragma once
7
8#include <OpaqueDelegateUtils.hpp>
9#include <SharedFunctions.hpp>
10
11namespace armnnOpaqueDelegate
12{
13
14TfLiteStatus VisitFloorOperator(DelegateData& delegateData,
15 TfLiteOpaqueContext* tfLiteContext,
16 TfLiteOpaqueNode* tfLiteNode,
17 int nodeIndex,
18 int32_t operatorCode)
19{
20 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
21 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
22
23 // Gather input indices and use to get input tensor.
24 int numInputs = 0;
25 const int* inputTensors;
26 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
27 {
28 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
29 tfLiteContext,
30 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
31 nodeIndex);
32 return kTfLiteError;
33 }
34
35 // Use input indices to get input tensors.
36 const TfLiteOpaqueTensor* tfLiteInputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
37 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
38 {
39 return kTfLiteError;
40 }
41
42 // Gather output indices and use to get output tensors.
43 int numOutputs = 0;
44 const int* outputTensors;
45 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
46 {
47 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
48 tfLiteContext,
49 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
50 nodeIndex);
51 return kTfLiteError;
52 }
53
54 const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
55 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
56 {
57 return kTfLiteError;
58 }
59
60 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
61 // NOTE: looks like the outputTensorInfo is the only thing that is required for the case
62 // where we are adding the floor layer so maybe move the other stuff inside the
63 // if !delegateData block for efficiency.
64 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
65
66 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
67 // support for the operator
68 // If supported, VisitFloorOperator will be called again to add the layer to the network as seen further below
69 if (!delegateData.m_Network)
70 {
71 return ValidateFloorOperator(delegateData, tfLiteContext, inputTensorInfo, outputTensorInfo);
72 }
73
74 // Add a Floor layer
Mike Kellya2806502023-08-03 10:42:11 +010075 auto layerName = GetName(armnn::LayerType::Floor, nodeIndex);
76 armnn::IConnectableLayer* layer = delegateData.m_Network->AddFloorLayer(layerName.c_str());
Matthew Sloyan48ec8132023-04-27 17:04:47 +010077 ARMNN_ASSERT(layer != nullptr);
78
79 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
80 outputSlot.SetTensorInfo(outputTensorInfo);
81
82 // try to connect the Constant Inputs if there are any
Mike Kellya2806502023-08-03 10:42:11 +010083 if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Matthew Sloyan48ec8132023-04-27 17:04:47 +010084 {
85 return kTfLiteError;
86 }
87
88 // Connect
89 return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
90}
91
92} // namespace armnnOpaqueDelegate