blob: 31c6ac3677eedcf5fa041c08481f1f62264e3453 [file] [log] [blame]
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
James Warda8578102020-11-13 18:05:04 +00008#include "DelegateUtils.hpp"
Finn Williams6f9f9902020-11-13 13:23:15 +00009
Sadik Armagan62483be2020-10-23 17:14:43 +010010#include <tensorflow/lite/builtin_ops.h>
11#include <tensorflow/lite/c/builtin_op_data.h>
12#include <tensorflow/lite/c/common.h>
13#include <tensorflow/lite/minimal_logging.h>
14
15namespace armnnDelegate
16{
17
James Warda8578102020-11-13 18:05:04 +000018TfLiteStatus ValidateSoftmaxOperator(DelegateData& delegateData,
19 TfLiteContext* tfLiteContext,
20 const armnn::TensorInfo& inputInfo,
21 const armnn::TensorInfo& outputTensorInfo,
22 const armnn::SoftmaxDescriptor& descriptor)
23{
24 bool isSupported = false;
Sadik Armaganbfa767c2022-02-09 14:58:03 +000025 FORWARD_LAYER_SUPPORT_FUNC("SOFTMAX",
James Warda8578102020-11-13 18:05:04 +000026 tfLiteContext,
27 IsSoftmaxSupported,
28 delegateData.m_Backends,
29 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010030 armnn::BackendId(),
James Warda8578102020-11-13 18:05:04 +000031 inputInfo,
32 outputTensorInfo,
33 descriptor);
34 return isSupported ? kTfLiteOk : kTfLiteError;
35}
36
37
38TfLiteStatus ValidateLogSoftmaxOperator(DelegateData& delegateData,
39 TfLiteContext* tfLiteContext,
40 const armnn::TensorInfo& inputInfo,
41 const armnn::TensorInfo& outputTensorInfo,
42 const armnn::LogSoftmaxDescriptor& descriptor)
43{
44 bool isSupported = false;
Sadik Armaganbfa767c2022-02-09 14:58:03 +000045 FORWARD_LAYER_SUPPORT_FUNC("LOG_SOFTMAX",
James Warda8578102020-11-13 18:05:04 +000046 tfLiteContext,
47 IsLogSoftmaxSupported,
48 delegateData.m_Backends,
49 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010050 armnn::BackendId(),
James Warda8578102020-11-13 18:05:04 +000051 inputInfo,
52 outputTensorInfo,
53 descriptor);
54 return isSupported ? kTfLiteOk : kTfLiteError;
55}
56
Sadik Armagan62483be2020-10-23 17:14:43 +010057TfLiteStatus VisitSoftmaxOperator(DelegateData& delegateData,
58 TfLiteContext* tfLiteContext,
59 TfLiteNode* tfLiteNode,
60 int nodeIndex,
61 int32_t softmaxOperatorCode)
62{
James Warda8578102020-11-13 18:05:04 +000063 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
64 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
Finn Williams6f9f9902020-11-13 13:23:15 +000065
James Warda8578102020-11-13 18:05:04 +000066 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
67 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
68 if (IsDynamicTensor(tfLiteInputTensor))
69 {
70 TF_LITE_MAYBE_KERNEL_LOG(
71 tfLiteContext,
72 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in node #%d: ",
73 nodeIndex);
74 return kTfLiteError;
75 }
76 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
77 if (IsDynamicTensor(tfLiteOutputTensor))
78 {
79 TF_LITE_MAYBE_KERNEL_LOG(
80 tfLiteContext,
81 "TfLiteArmnnDelegate: Dynamic output tensors are not supported in node #%d: ",
82 nodeIndex);
83 return kTfLiteError;
84 }
85
86 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010087 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
James Warda8578102020-11-13 18:05:04 +000088
89
90 if (!delegateData.m_Network)
91 {
92 switch(softmaxOperatorCode)
93 {
94 case kTfLiteBuiltinSoftmax:
95 {
96 armnn::SoftmaxDescriptor descriptor;
97 auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(tfLiteNode->builtin_data);
98 descriptor.m_Beta = params->beta;
99 return ValidateSoftmaxOperator(delegateData,
100 tfLiteContext,
101 inputTensorInfo,
102 outputTensorInfo,
103 descriptor);
104 }
105 case kTfLiteBuiltinLogSoftmax:
106 {
107 armnn::LogSoftmaxDescriptor descriptor;
108 return ValidateLogSoftmaxOperator(delegateData,
109 tfLiteContext,
110 inputTensorInfo,
111 outputTensorInfo,
112 descriptor);
113 }
114 default:
115 return kTfLiteError;
116 }
117 }
118
119 armnn::IConnectableLayer* softmaxLayer = nullptr;
120
121 switch(softmaxOperatorCode)
122 {
123 case kTfLiteBuiltinSoftmax:
124 {
125 armnn::SoftmaxDescriptor descriptor;
126 auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(tfLiteNode->builtin_data);
127 descriptor.m_Beta = params->beta;
128 softmaxLayer = delegateData.m_Network->AddSoftmaxLayer(descriptor);
129 break;
130 }
131 case kTfLiteBuiltinLogSoftmax:
132 {
133 armnn::LogSoftmaxDescriptor descriptor;
134 softmaxLayer = delegateData.m_Network->AddLogSoftmaxLayer(descriptor);
135 break;
136 }
137 default:
138 return kTfLiteError;
139 }
140 ARMNN_ASSERT(softmaxLayer != nullptr);
141
142 armnn::IOutputSlot& outputSlot = softmaxLayer->GetOutputSlot(0);
143 outputSlot.SetTensorInfo(outputTensorInfo);
144
Ryan OShea4c231de2023-01-17 15:19:20 +0000145 // try to connect the Constant Inputs if there are any
146 if(ProcessInputs(softmaxLayer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
147 {
148 return kTfLiteError;
149 }
150
James Warda8578102020-11-13 18:05:04 +0000151 // Connect
152 return Connect(softmaxLayer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100153}
154
155} // namespace armnnDelegate