blob: 51aa8f143bac2cac7540891356ef75c6da47c8e9 [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include "DelegateUtils.hpp"
9
10#include <tensorflow/lite/builtin_ops.h>
11#include <tensorflow/lite/c/builtin_op_data.h>
12#include <tensorflow/lite/c/common.h>
13#include <tensorflow/lite/minimal_logging.h>
14#include <numeric>
15
16namespace armnnDelegate
17{
18
19TfLiteStatus VisitShapeOperator(DelegateData& delegateData,
20 TfLiteContext* tfLiteContext,
21 TfLiteNode* tfLiteNode,
22 int nodeIndex,
23 int32_t operatorCode)
24{
25 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
26 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
27
28 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
29 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
30 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
31 {
32 return kTfLiteError;
33 }
34
35 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
36 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
37 {
38 return kTfLiteError;
39 }
40
41 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
42 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
43
44 auto* shapeParameters = reinterpret_cast<TfLiteShapeParams*>(tfLiteNode->builtin_data);
45 if ( shapeParameters->out_type != kTfLiteInt32 && shapeParameters->out_type != kTfLiteInt64 )
46 {
47 TF_LITE_MAYBE_KERNEL_LOG(
48 tfLiteContext,
49 "TfLiteArmnnDelegate: output_type data type is not supported in operator #%d node #%d: ",
50 operatorCode, nodeIndex);
51 return kTfLiteError;
52 }
53
54 bool isSupported = false;
55 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
56 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +000057 FORWARD_LAYER_SUPPORT_FUNC("SHAPE",
Keith Davis0176fd82021-06-01 17:36:32 +010058 tfLiteContext,
59 IsShapeSupported,
60 delegateData.m_Backends,
61 isSupported,
62 inputTensorInfo,
63 outInfo);
64 };
65
66 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
67 // support for the operator
68 // If supported, VisitShapeOperator will be called again to add the layer to the network as seen further below
69 if (!delegateData.m_Network)
70 {
71 validateFunc(outputTensorInfo, isSupported);
72 return isSupported ? kTfLiteOk : kTfLiteError;
73 }
74
75 // Add a Shape layer
76 armnn::IConnectableLayer* layer = delegateData.m_Network->AddShapeLayer();
77 ARMNN_ASSERT(layer != nullptr);
78
79 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
80 outputSlot.SetTensorInfo(outputTensorInfo);
81
82 // Connect
83 return Connect(layer, tfLiteNode, delegateData);
84}
85
86} // namespace armnnDelegate