//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <OpaqueDelegateUtils.hpp>
namespace armnnOpaqueDelegate
{
12TfLiteStatus ValidateSoftmaxOperator(DelegateData& delegateData,
13 TfLiteOpaqueContext* tfLiteContext,
14 const armnn::TensorInfo& inputInfo,
15 const armnn::TensorInfo& outputTensorInfo,
16 const armnn::SoftmaxDescriptor& descriptor)
17{
18 bool isSupported = false;
19 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("SOFTMAX",
20 tfLiteContext,
21 IsSoftmaxSupported,
22 delegateData.m_Backends,
23 isSupported,
24 armnn::BackendId(),
25 inputInfo,
26 outputTensorInfo,
27 descriptor);
28 return isSupported ? kTfLiteOk : kTfLiteError;
29}
30
31TfLiteStatus ValidateLogSoftmaxOperator(DelegateData& delegateData,
32 TfLiteOpaqueContext* tfLiteContext,
33 const armnn::TensorInfo& inputInfo,
34 const armnn::TensorInfo& outputTensorInfo,
35 const armnn::LogSoftmaxDescriptor& descriptor)
36{
37 bool isSupported = false;
38 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("LOG_SOFTMAX",
39 tfLiteContext,
40 IsLogSoftmaxSupported,
41 delegateData.m_Backends,
42 isSupported,
43 armnn::BackendId(),
44 inputInfo,
45 outputTensorInfo,
46 descriptor);
47 return isSupported ? kTfLiteOk : kTfLiteError;
48}
49
50TfLiteStatus VisitSoftmaxOperator(DelegateData& delegateData,
51 TfLiteOpaqueContext* tfLiteContext,
52 TfLiteOpaqueNode* tfLiteNode,
53 int nodeIndex,
54 int32_t tfliteSoftmaxOperatorCode)
55{
56 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
57 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
58
59 // Gather input indices and use to get input tensor.
60 const int* inputTensors;
61 int numInputs;
62 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
63 {
64 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
65 tfLiteContext,
66 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
67 nodeIndex);
68 return kTfLiteError;
69 }
70
71 const TfLiteOpaqueTensor* tfLiteInputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
72 if (!IsValid(tfLiteContext, tfLiteInputTensor, tfliteSoftmaxOperatorCode, nodeIndex))
73 {
74 return kTfLiteError;
75 }
76
77 // Gather output indices and use to get output tensor.
78 const int* outputTensors;
79 int numOutputs = 0;
80 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
81 {
82 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
83 tfLiteContext,
84 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
85 nodeIndex);
86 return kTfLiteError;
87 }
88
89 const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
90 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfliteSoftmaxOperatorCode, nodeIndex))
91 {
92 return kTfLiteError;
93 }
94
95 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
96 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
97
98 if (!delegateData.m_Network)
99 {
100 switch(tfliteSoftmaxOperatorCode)
101 {
102 case kTfLiteBuiltinSoftmax:
103 {
104 armnn::SoftmaxDescriptor descriptor;
105 auto* nodeParams = reinterpret_cast<TfLiteSoftmaxParams*>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode));
106 descriptor.m_Beta = nodeParams->beta;
107 return ValidateSoftmaxOperator(delegateData,
108 tfLiteContext,
109 inputTensorInfo,
110 outputTensorInfo,
111 descriptor);
112 }
113 case kTfLiteBuiltinLogSoftmax:
114 {
115 armnn::LogSoftmaxDescriptor descriptor;
116 return ValidateLogSoftmaxOperator(delegateData,
117 tfLiteContext,
118 inputTensorInfo,
119 outputTensorInfo,
120 descriptor);
121 }
122 default:
123 return kTfLiteError;
124 }
125 }
126
127 armnn::IConnectableLayer* softmaxLayer = nullptr;
Mike Kellya2806502023-08-03 10:42:11 +0100128 auto layerName = GetName(armnn::LayerType::Softmax, nodeIndex);
129
Teresa Charlin42362962023-04-28 14:23:33 +0100130 switch(tfliteSoftmaxOperatorCode)
131 {
132 case kTfLiteBuiltinSoftmax:
133 {
134 armnn::SoftmaxDescriptor descriptor;
135 auto* nodeParameters = reinterpret_cast<TfLiteSoftmaxParams*>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode));
136 descriptor.m_Beta = nodeParameters->beta;
Mike Kellya2806502023-08-03 10:42:11 +0100137 softmaxLayer = delegateData.m_Network->AddSoftmaxLayer(descriptor, layerName.c_str());
Teresa Charlin42362962023-04-28 14:23:33 +0100138 break;
139 }
140 case kTfLiteBuiltinLogSoftmax:
141 {
142 armnn::LogSoftmaxDescriptor descriptor;
Mike Kellya2806502023-08-03 10:42:11 +0100143 softmaxLayer = delegateData.m_Network->AddLogSoftmaxLayer(descriptor, layerName.c_str());
Teresa Charlin42362962023-04-28 14:23:33 +0100144 break;
145 }
146 default:
147 return kTfLiteError;
148 }
149
150 ARMNN_ASSERT(softmaxLayer != nullptr);
151 armnn::IOutputSlot& outputSlot = softmaxLayer->GetOutputSlot(0);
152 outputSlot.SetTensorInfo(outputTensorInfo);
153
154 // try to connect the Constant Inputs if there are any
Mike Kellya2806502023-08-03 10:42:11 +0100155 if (ProcessInputs(softmaxLayer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Teresa Charlin42362962023-04-28 14:23:33 +0100156 {
157 return kTfLiteError;
158 }
159
160 // Connect
161 return Connect(softmaxLayer, tfLiteContext, tfLiteNode, delegateData);
162}
} // namespace armnnOpaqueDelegate