//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "DelegateUtils.hpp"

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>

namespace armnnDelegate
{

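// Queries the configured backends (via FORWARD_LAYER_SUPPORT_FUNC) to check whether a Softmax layer
// with the given tensor infos and descriptor is supported. Returns kTfLiteOk if any backend supports it.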
TfLiteStatus ValidateSoftmaxOperator(DelegateData& delegateData,
                                     TfLiteContext* tfLiteContext,
                                     const armnn::TensorInfo& inputInfo,
                                     const armnn::TensorInfo& outputTensorInfo,
                                     const armnn::SoftmaxDescriptor& descriptor)
{
    bool isSupported = false;
    FORWARD_LAYER_SUPPORT_FUNC("SOFTMAX",
                               tfLiteContext,
                               IsSoftmaxSupported,
                               delegateData.m_Backends,
                               isSupported,
                               inputInfo,
                               outputTensorInfo,
                               descriptor);
    return isSupported ? kTfLiteOk : kTfLiteError;
}

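// Same backend support check as above, but for a LogSoftmax layer.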
TfLiteStatus ValidateLogSoftmaxOperator(DelegateData& delegateData,
                                        TfLiteContext* tfLiteContext,
                                        const armnn::TensorInfo& inputInfo,
                                        const armnn::TensorInfo& outputTensorInfo,
                                        const armnn::LogSoftmaxDescriptor& descriptor)
{
    bool isSupported = false;
    FORWARD_LAYER_SUPPORT_FUNC("LOG_SOFTMAX",
                               tfLiteContext,
                               IsLogSoftmaxSupported,
                               delegateData.m_Backends,
                               isSupported,
                               inputInfo,
                               outputTensorInfo,
                               descriptor);
    return isSupported ? kTfLiteOk : kTfLiteError;
}

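// Handles both the SOFTMAX and LOG_SOFTMAX builtin operators. When delegateData.m_Network is null the
// call only validates backend support; otherwise the corresponding Arm NN layer is added and connected.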
TfLiteStatus VisitSoftmaxOperator(DelegateData& delegateData,
                                  TfLiteContext* tfLiteContext,
                                  TfLiteNode* tfLiteNode,
                                  int nodeIndex,
                                  int32_t softmaxOperatorCode)
{
    TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

    const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
    const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
    if (IsDynamicTensor(tfLiteInputTensor))
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Dynamic input tensors are not supported in node #%d: ",
            nodeIndex);
        return kTfLiteError;
    }
    const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
    if (IsDynamicTensor(tfLiteOutputTensor))
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Dynamic output tensors are not supported in node #%d: ",
            nodeIndex);
        return kTfLiteError;
    }

    const armnn::TensorInfo& inputTensorInfo  = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
    const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);

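    // Validation-only pass: no network has been created yet, so just check backend support.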
    if (!delegateData.m_Network)
    {
        switch(softmaxOperatorCode)
        {
            case kTfLiteBuiltinSoftmax:
            {
                armnn::SoftmaxDescriptor descriptor;
                auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(tfLiteNode->builtin_data);
                descriptor.m_Beta = params->beta;
                return ValidateSoftmaxOperator(delegateData,
                                               tfLiteContext,
                                               inputTensorInfo,
                                               outputTensorInfo,
                                               descriptor);
            }
            case kTfLiteBuiltinLogSoftmax:
            {
                armnn::LogSoftmaxDescriptor descriptor;
                return ValidateLogSoftmaxOperator(delegateData,
                                                  tfLiteContext,
                                                  inputTensorInfo,
                                                  outputTensorInfo,
                                                  descriptor);
            }
            default:
                return kTfLiteError;
        }
    }

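    // Add the corresponding Arm NN layer to the network being built.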
    armnn::IConnectableLayer* softmaxLayer = nullptr;

    switch(softmaxOperatorCode)
    {
        case kTfLiteBuiltinSoftmax:
        {
            armnn::SoftmaxDescriptor descriptor;
            auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(tfLiteNode->builtin_data);
            descriptor.m_Beta = params->beta;
            softmaxLayer = delegateData.m_Network->AddSoftmaxLayer(descriptor);
            break;
        }
        case kTfLiteBuiltinLogSoftmax:
        {
            armnn::LogSoftmaxDescriptor descriptor;
            softmaxLayer = delegateData.m_Network->AddLogSoftmaxLayer(descriptor);
            break;
        }
        default:
            return kTfLiteError;
    }
    ARMNN_ASSERT(softmaxLayer != nullptr);

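    // Set the output tensor info on the layer's single output slot before connecting it into the graph.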
    armnn::IOutputSlot& outputSlot = softmaxLayer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);

    // Connect
    return Connect(softmaxLayer, tfLiteNode, delegateData);
}

} // namespace armnnDelegate