blob: 016af11e212d58db7697a7ce293b5c7760d65b1e [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
Jim Flynn4b2f3472021-10-13 21:20:07 +01008#include "SharedFunctions.hpp"
9
Sadik Armagan62483be2020-10-23 17:14:43 +010010#include <tensorflow/lite/builtin_ops.h>
11#include <tensorflow/lite/c/builtin_op_data.h>
12#include <tensorflow/lite/c/common.h>
13#include <tensorflow/lite/minimal_logging.h>
14
15namespace armnnDelegate
16{
17
18TfLiteStatus VisitFloorOperator(DelegateData& delegateData,
19 TfLiteContext* tfLiteContext,
20 TfLiteNode* tfLiteNode,
21 int nodeIndex,
22 int32_t operatorCode)
23{
Sadik Armagan788e2c62021-02-10 16:26:44 +000024 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
Finn Williams6f9f9902020-11-13 13:23:15 +000026
Sadik Armagan788e2c62021-02-10 16:26:44 +000027 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
28 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
29 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
30 {
31 return kTfLiteError;
32 }
33
34 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
35 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
36 {
37 return kTfLiteError;
38 }
39
40 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Jim Flynn4b2f3472021-10-13 21:20:07 +010041 // NOTE: looks like the outputTensorInfo is the only thing that is required for the case
42 // where we are adding the floor layer so maybe move the other stuff inside the
43 // if !delegateData block for efficiency.
Sadik Armagan788e2c62021-02-10 16:26:44 +000044 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
45
Sadik Armagan788e2c62021-02-10 16:26:44 +000046 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
47 // support for the operator
48 // If supported, VisitFloorOperator will be called again to add the layer to the network as seen further below
49 if (!delegateData.m_Network)
50 {
Jim Flynn4b2f3472021-10-13 21:20:07 +010051 return ValidateFloorOperator(delegateData, tfLiteContext, inputTensorInfo, outputTensorInfo);
Sadik Armagan788e2c62021-02-10 16:26:44 +000052 }
53
54 // Add a Floor layer
55 armnn::IConnectableLayer* layer = delegateData.m_Network->AddFloorLayer();
56 ARMNN_ASSERT(layer != nullptr);
57
58 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
59 outputSlot.SetTensorInfo(outputTensorInfo);
60
61 // Connect
62 return Connect(layer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +010063}
64
65} // namespace armnnDelegate