blob: d797563ab5ece62e0f2a51b5e83c0d4e7993e283 [file] [log] [blame]
Keith Davis0176fd82021-06-01 17:36:32 +01001//
Ryan OShea4c231de2023-01-17 15:19:20 +00002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Keith Davis0176fd82021-06-01 17:36:32 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "DelegateUtils.hpp"
9
10#include <tensorflow/lite/builtin_ops.h>
11#include <tensorflow/lite/c/builtin_op_data.h>
12#include <tensorflow/lite/c/common.h>
13#include <tensorflow/lite/minimal_logging.h>
14#include <numeric>
15
16namespace armnnDelegate
17{
18
19TfLiteStatus VisitShapeOperator(DelegateData& delegateData,
20 TfLiteContext* tfLiteContext,
21 TfLiteNode* tfLiteNode,
22 int nodeIndex,
23 int32_t operatorCode)
24{
25 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
26 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
27
28 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
29 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
30 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
31 {
32 return kTfLiteError;
33 }
34
35 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
36 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
37 {
38 return kTfLiteError;
39 }
40
41 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010042 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Keith Davis0176fd82021-06-01 17:36:32 +010043
44 auto* shapeParameters = reinterpret_cast<TfLiteShapeParams*>(tfLiteNode->builtin_data);
45 if ( shapeParameters->out_type != kTfLiteInt32 && shapeParameters->out_type != kTfLiteInt64 )
46 {
47 TF_LITE_MAYBE_KERNEL_LOG(
48 tfLiteContext,
49 "TfLiteArmnnDelegate: output_type data type is not supported in operator #%d node #%d: ",
50 operatorCode, nodeIndex);
51 return kTfLiteError;
52 }
53
54 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +010055 armnn::BackendId setBackend;
Keith Davis0176fd82021-06-01 17:36:32 +010056 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
57 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +000058 FORWARD_LAYER_SUPPORT_FUNC("SHAPE",
Keith Davis0176fd82021-06-01 17:36:32 +010059 tfLiteContext,
60 IsShapeSupported,
61 delegateData.m_Backends,
62 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010063 setBackend,
Keith Davis0176fd82021-06-01 17:36:32 +010064 inputTensorInfo,
65 outInfo);
66 };
67
68 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
69 // support for the operator
70 // If supported, VisitShapeOperator will be called again to add the layer to the network as seen further below
71 if (!delegateData.m_Network)
72 {
73 validateFunc(outputTensorInfo, isSupported);
74 return isSupported ? kTfLiteOk : kTfLiteError;
75 }
76
77 // Add a Shape layer
78 armnn::IConnectableLayer* layer = delegateData.m_Network->AddShapeLayer();
Cathal Corbett53837672022-09-01 11:34:37 +010079 layer->SetBackendId(setBackend);
Keith Davis0176fd82021-06-01 17:36:32 +010080 ARMNN_ASSERT(layer != nullptr);
81
82 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
83 outputSlot.SetTensorInfo(outputTensorInfo);
84
Ryan OShea4c231de2023-01-17 15:19:20 +000085 // try to connect the Constant Inputs if there are any
86 if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
87 {
88 return kTfLiteError;
89 }
90
Keith Davis0176fd82021-06-01 17:36:32 +010091 // Connect
92 return Connect(layer, tfLiteNode, delegateData);
93}
94
95} // namespace armnnDelegate