blob: dd36a3d05f3bf5265d5a3ea1f4adbf5ab437eed2 [file] [log] [blame]
Sadik Armagan62483be2020-10-23 17:14:43 +01001//
Ryan OShea4c231de2023-01-17 15:19:20 +00002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Sadik Armagan62483be2020-10-23 17:14:43 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Matthew Sloyan11572322023-03-16 10:17:51 +00008#include <ClassicDelegateUtils.hpp>
Finn Williams6f9f9902020-11-13 13:23:15 +00009
Sadik Armagan62483be2020-10-23 17:14:43 +010010#include <tensorflow/lite/builtin_ops.h>
11#include <tensorflow/lite/c/builtin_op_data.h>
12#include <tensorflow/lite/c/common.h>
13#include <tensorflow/lite/minimal_logging.h>
14
15namespace armnnDelegate
16{
17
James Warda8578102020-11-13 18:05:04 +000018TfLiteStatus ValidateSoftmaxOperator(DelegateData& delegateData,
19 TfLiteContext* tfLiteContext,
20 const armnn::TensorInfo& inputInfo,
21 const armnn::TensorInfo& outputTensorInfo,
22 const armnn::SoftmaxDescriptor& descriptor)
23{
24 bool isSupported = false;
Sadik Armaganbfa767c2022-02-09 14:58:03 +000025 FORWARD_LAYER_SUPPORT_FUNC("SOFTMAX",
James Warda8578102020-11-13 18:05:04 +000026 tfLiteContext,
27 IsSoftmaxSupported,
28 delegateData.m_Backends,
29 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010030 armnn::BackendId(),
James Warda8578102020-11-13 18:05:04 +000031 inputInfo,
32 outputTensorInfo,
33 descriptor);
34 return isSupported ? kTfLiteOk : kTfLiteError;
35}
36
37
38TfLiteStatus ValidateLogSoftmaxOperator(DelegateData& delegateData,
39 TfLiteContext* tfLiteContext,
40 const armnn::TensorInfo& inputInfo,
41 const armnn::TensorInfo& outputTensorInfo,
42 const armnn::LogSoftmaxDescriptor& descriptor)
43{
44 bool isSupported = false;
Sadik Armaganbfa767c2022-02-09 14:58:03 +000045 FORWARD_LAYER_SUPPORT_FUNC("LOG_SOFTMAX",
James Warda8578102020-11-13 18:05:04 +000046 tfLiteContext,
47 IsLogSoftmaxSupported,
48 delegateData.m_Backends,
49 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010050 armnn::BackendId(),
James Warda8578102020-11-13 18:05:04 +000051 inputInfo,
52 outputTensorInfo,
53 descriptor);
54 return isSupported ? kTfLiteOk : kTfLiteError;
55}
56
Sadik Armagan62483be2020-10-23 17:14:43 +010057TfLiteStatus VisitSoftmaxOperator(DelegateData& delegateData,
58 TfLiteContext* tfLiteContext,
59 TfLiteNode* tfLiteNode,
60 int nodeIndex,
61 int32_t softmaxOperatorCode)
62{
James Warda8578102020-11-13 18:05:04 +000063 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
64 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
Finn Williams6f9f9902020-11-13 13:23:15 +000065
James Warda8578102020-11-13 18:05:04 +000066 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
67 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
68 if (IsDynamicTensor(tfLiteInputTensor))
69 {
70 TF_LITE_MAYBE_KERNEL_LOG(
71 tfLiteContext,
72 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in node #%d: ",
73 nodeIndex);
74 return kTfLiteError;
75 }
76 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
77 if (IsDynamicTensor(tfLiteOutputTensor))
78 {
79 TF_LITE_MAYBE_KERNEL_LOG(
80 tfLiteContext,
81 "TfLiteArmnnDelegate: Dynamic output tensors are not supported in node #%d: ",
82 nodeIndex);
83 return kTfLiteError;
84 }
85
86 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010087 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
James Warda8578102020-11-13 18:05:04 +000088
89
90 if (!delegateData.m_Network)
91 {
92 switch(softmaxOperatorCode)
93 {
94 case kTfLiteBuiltinSoftmax:
95 {
96 armnn::SoftmaxDescriptor descriptor;
97 auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(tfLiteNode->builtin_data);
98 descriptor.m_Beta = params->beta;
99 return ValidateSoftmaxOperator(delegateData,
100 tfLiteContext,
101 inputTensorInfo,
102 outputTensorInfo,
103 descriptor);
104 }
105 case kTfLiteBuiltinLogSoftmax:
106 {
107 armnn::LogSoftmaxDescriptor descriptor;
108 return ValidateLogSoftmaxOperator(delegateData,
109 tfLiteContext,
110 inputTensorInfo,
111 outputTensorInfo,
112 descriptor);
113 }
114 default:
115 return kTfLiteError;
116 }
117 }
118
Mike Kelly07169c82023-08-02 13:23:09 +0100119 auto layerName = GetLayerName(armnn::LayerType::Softmax, nodeIndex);
James Warda8578102020-11-13 18:05:04 +0000120 armnn::IConnectableLayer* softmaxLayer = nullptr;
121
122 switch(softmaxOperatorCode)
123 {
124 case kTfLiteBuiltinSoftmax:
125 {
126 armnn::SoftmaxDescriptor descriptor;
127 auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(tfLiteNode->builtin_data);
128 descriptor.m_Beta = params->beta;
Mike Kelly07169c82023-08-02 13:23:09 +0100129 softmaxLayer = delegateData.m_Network->AddSoftmaxLayer(descriptor, layerName.c_str());
James Warda8578102020-11-13 18:05:04 +0000130 break;
131 }
132 case kTfLiteBuiltinLogSoftmax:
133 {
134 armnn::LogSoftmaxDescriptor descriptor;
Mike Kelly07169c82023-08-02 13:23:09 +0100135 softmaxLayer = delegateData.m_Network->AddLogSoftmaxLayer(descriptor, layerName.c_str());
James Warda8578102020-11-13 18:05:04 +0000136 break;
137 }
138 default:
139 return kTfLiteError;
140 }
141 ARMNN_ASSERT(softmaxLayer != nullptr);
142
143 armnn::IOutputSlot& outputSlot = softmaxLayer->GetOutputSlot(0);
144 outputSlot.SetTensorInfo(outputTensorInfo);
145
Ryan OShea4c231de2023-01-17 15:19:20 +0000146 // try to connect the Constant Inputs if there are any
Mike Kelly07169c82023-08-02 13:23:09 +0100147 if (ProcessInputs(softmaxLayer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Ryan OShea4c231de2023-01-17 15:19:20 +0000148 {
149 return kTfLiteError;
150 }
151
James Warda8578102020-11-13 18:05:04 +0000152 // Connect
153 return Connect(softmaxLayer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100154}
155
156} // namespace armnnDelegate