//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <OpaqueDelegateUtils.hpp>

namespace armnnOpaqueDelegate
{
12TfLiteStatus ValidateSoftmaxOperator(DelegateData& delegateData,
13 TfLiteOpaqueContext* tfLiteContext,
14 const armnn::TensorInfo& inputInfo,
15 const armnn::TensorInfo& outputTensorInfo,
16 const armnn::SoftmaxDescriptor& descriptor)
17{
18 bool isSupported = false;
19 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("SOFTMAX",
20 tfLiteContext,
21 IsSoftmaxSupported,
22 delegateData.m_Backends,
23 isSupported,
24 armnn::BackendId(),
25 inputInfo,
26 outputTensorInfo,
27 descriptor);
28 return isSupported ? kTfLiteOk : kTfLiteError;
29}
31TfLiteStatus ValidateLogSoftmaxOperator(DelegateData& delegateData,
32 TfLiteOpaqueContext* tfLiteContext,
33 const armnn::TensorInfo& inputInfo,
34 const armnn::TensorInfo& outputTensorInfo,
35 const armnn::LogSoftmaxDescriptor& descriptor)
36{
37 bool isSupported = false;
38 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("LOG_SOFTMAX",
39 tfLiteContext,
40 IsLogSoftmaxSupported,
41 delegateData.m_Backends,
42 isSupported,
43 armnn::BackendId(),
44 inputInfo,
45 outputTensorInfo,
46 descriptor);
47 return isSupported ? kTfLiteOk : kTfLiteError;
48}
50TfLiteStatus VisitSoftmaxOperator(DelegateData& delegateData,
51 TfLiteOpaqueContext* tfLiteContext,
52 TfLiteOpaqueNode* tfLiteNode,
53 int nodeIndex,
54 int32_t tfliteSoftmaxOperatorCode)
55{
56 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
57 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
58
59 // Gather input indices and use to get input tensor.
60 const int* inputTensors;
61 int numInputs;
62 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
63 {
64 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
65 tfLiteContext,
66 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
67 nodeIndex);
68 return kTfLiteError;
69 }
70
71 const TfLiteOpaqueTensor* tfLiteInputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
72 if (!IsValid(tfLiteContext, tfLiteInputTensor, tfliteSoftmaxOperatorCode, nodeIndex))
73 {
74 return kTfLiteError;
75 }
76
77 // Gather output indices and use to get output tensor.
78 const int* outputTensors;
79 int numOutputs = 0;
80 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
81 {
82 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
83 tfLiteContext,
84 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
85 nodeIndex);
86 return kTfLiteError;
87 }
88
89 const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
90 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfliteSoftmaxOperatorCode, nodeIndex))
91 {
92 return kTfLiteError;
93 }
94
95 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
96 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
97
98 if (!delegateData.m_Network)
99 {
100 switch(tfliteSoftmaxOperatorCode)
101 {
102 case kTfLiteBuiltinSoftmax:
103 {
104 armnn::SoftmaxDescriptor descriptor;
105 auto* nodeParams = reinterpret_cast<TfLiteSoftmaxParams*>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode));
106 descriptor.m_Beta = nodeParams->beta;
107 return ValidateSoftmaxOperator(delegateData,
108 tfLiteContext,
109 inputTensorInfo,
110 outputTensorInfo,
111 descriptor);
112 }
113 case kTfLiteBuiltinLogSoftmax:
114 {
115 armnn::LogSoftmaxDescriptor descriptor;
116 return ValidateLogSoftmaxOperator(delegateData,
117 tfLiteContext,
118 inputTensorInfo,
119 outputTensorInfo,
120 descriptor);
121 }
122 default:
123 return kTfLiteError;
124 }
125 }
126
127 armnn::IConnectableLayer* softmaxLayer = nullptr;
128 switch(tfliteSoftmaxOperatorCode)
129 {
130 case kTfLiteBuiltinSoftmax:
131 {
132 armnn::SoftmaxDescriptor descriptor;
133 auto* nodeParameters = reinterpret_cast<TfLiteSoftmaxParams*>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode));
134 descriptor.m_Beta = nodeParameters->beta;
135 softmaxLayer = delegateData.m_Network->AddSoftmaxLayer(descriptor);
136 break;
137 }
138 case kTfLiteBuiltinLogSoftmax:
139 {
140 armnn::LogSoftmaxDescriptor descriptor;
141 softmaxLayer = delegateData.m_Network->AddLogSoftmaxLayer(descriptor);
142 break;
143 }
144 default:
145 return kTfLiteError;
146 }
147
148 ARMNN_ASSERT(softmaxLayer != nullptr);
149 armnn::IOutputSlot& outputSlot = softmaxLayer->GetOutputSlot(0);
150 outputSlot.SetTensorInfo(outputTensorInfo);
151
152 // try to connect the Constant Inputs if there are any
153 if(ProcessInputs(softmaxLayer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
154 {
155 return kTfLiteError;
156 }
157
158 // Connect
159 return Connect(softmaxLayer, tfLiteContext, tfLiteNode, delegateData);
160}
} // namespace armnnOpaqueDelegate