//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/utility/IgnoreUnused.hpp>

#include "OpaqueDelegateUtils.hpp"

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>
#include <numeric>
17namespace armnnOpaqueDelegate
18{
19
20TfLiteStatus VisitCastOperator(DelegateData& delegateData,
21 TfLiteOpaqueContext* tfLiteContext,
22 TfLiteOpaqueNode* tfLiteNode,
23 int nodeIndex,
24 int32_t operatorCode)
25{
26 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
27 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
28 int numInputs = 0;
29 const int* inputTensors;
30 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
31 {
32 return kTfLiteError;
33 }
34
35 // This layer only has 1 input, so we can directly assign tensor[0] to a new opaque tensor
36 const TfLiteOpaqueTensor*
37 tfLiteInputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[numInputs-1]);
38 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
39 {
40 return kTfLiteError;
41 }
42
43 int numOutputs = 0;
44 const int* outputTensors;
45 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
46 {
47 return kTfLiteError;
48 }
49
50 // This layer only has 1 output, so we can directly assign tensor[0] to a new opaque tensor
51 const TfLiteOpaqueTensor*
52 tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[numOutputs-1]);
53 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
54 {
55 return kTfLiteError;
56 }
57
58 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
59 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
60
61 bool isSupported = false;
62 armnn::BackendId setBackend;
63 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported) {
64 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("CAST",
65 tfLiteContext,
66 IsCastSupported,
67 delegateData.m_Backends,
68 isSupported,
69 setBackend,
70 inputTensorInfo,
71 outInfo);
72 };
73
74 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
75 // support for the operator
76 // If supported, VisitCastOperator will be called again to add the layer to the network as seen further below
77 if (!delegateData.m_Network)
78 {
79 validateFunc(outputTensorInfo, isSupported);
80 return isSupported ? kTfLiteOk : kTfLiteError;
81 }
82
83 // Add a Cast layer
84 armnn::IConnectableLayer* layer = delegateData.m_Network->AddCastLayer();
85 layer->SetBackendId(setBackend);
86 ARMNN_ASSERT(layer != nullptr);
87
88 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
89 outputSlot.SetTensorInfo(outputTensorInfo);
90
91 // try to connect the Constant Inputs if there are any
92 if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk)
93 {
94 return kTfLiteError;
95 }
96
97 // Connect
98 return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
99}
100}