IVGCVSW-5395 TfLiteDelegate: Implement the Softmax operators

Signed-off-by: James Ward <james.ward@arm.com>
Change-Id: I9f098c6b62ebb08e727aa8547e08bddc0b814705
diff --git a/delegate/src/Softmax.hpp b/delegate/src/Softmax.hpp
index ddadbc7..0de8e14 100644
--- a/delegate/src/Softmax.hpp
+++ b/delegate/src/Softmax.hpp
@@ -5,7 +5,7 @@
 
 #pragma once
 
-#include <armnn/utility/IgnoreUnused.hpp>
+#include "DelegateUtils.hpp"
 
 #include <tensorflow/lite/builtin_ops.h>
 #include <tensorflow/lite/c/builtin_op_data.h>
@@ -15,19 +15,133 @@
 namespace armnnDelegate
 {
 
+TfLiteStatus ValidateSoftmaxOperator(DelegateData& delegateData,
+                                     TfLiteContext* tfLiteContext,
+                                     const armnn::TensorInfo& inputInfo,
+                                     const armnn::TensorInfo& outputTensorInfo,
+                                     const armnn::SoftmaxDescriptor& descriptor)
+{ // Runs the IsSoftmaxSupported check over delegateData.m_Backends and maps the resulting flag to a TfLiteStatus.
+    bool isSupported = false; // pessimistic default; presumably updated by the macro below (defined elsewhere - confirm)
+    FORWARD_LAYER_SUPPORT_FUNC(__func__,
+                               tfLiteContext,
+                               IsSoftmaxSupported,
+                               delegateData.m_Backends,
+                               isSupported,
+                               inputInfo,
+                               outputTensorInfo,
+                               descriptor);
+    return isSupported ? kTfLiteOk : kTfLiteError; // flag set -> kTfLiteOk, otherwise kTfLiteError
+}
+
+
+TfLiteStatus ValidateLogSoftmaxOperator(DelegateData& delegateData,
+                                        TfLiteContext* tfLiteContext,
+                                        const armnn::TensorInfo& inputInfo,
+                                        const armnn::TensorInfo& outputTensorInfo,
+                                        const armnn::LogSoftmaxDescriptor& descriptor)
+{ // Runs the IsLogSoftmaxSupported check over delegateData.m_Backends and maps the resulting flag to a TfLiteStatus.
+    bool isSupported = false; // pessimistic default; presumably updated by the macro below (defined elsewhere - confirm)
+    FORWARD_LAYER_SUPPORT_FUNC(__func__,
+                               tfLiteContext,
+                               IsLogSoftmaxSupported,
+                               delegateData.m_Backends,
+                               isSupported,
+                               inputInfo,
+                               outputTensorInfo,
+                               descriptor);
+    return isSupported ? kTfLiteOk : kTfLiteError; // flag set -> kTfLiteOk, otherwise kTfLiteError
+}
+
 TfLiteStatus VisitSoftmaxOperator(DelegateData& delegateData,
                                   TfLiteContext* tfLiteContext,
                                   TfLiteNode* tfLiteNode,
                                   int nodeIndex,
                                   int32_t softmaxOperatorCode)
 {
-    armnn::IgnoreUnused(delegateData,
-                        tfLiteContext,
-                        tfLiteNode,
-                        nodeIndex,
-                        softmaxOperatorCode);
+    TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex)); // node is checked against an input count of 1
+    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex)); // node is checked against an output count of 1
 
-    return kTfLiteError;
+    const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors; // tensor table indexed by the node's input/output ids
+    const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]]; // the single input tensor
+    if (IsDynamicTensor(tfLiteInputTensor)) // dynamic inputs are rejected with a logged error
+    {
+        TF_LITE_MAYBE_KERNEL_LOG(
+            tfLiteContext,
+            "TfLiteArmnnDelegate: Dynamic input tensors are not supported in node #%d: ",
+            nodeIndex);
+        return kTfLiteError;
+    }
+    const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]]; // the single output tensor
+    if (IsDynamicTensor(tfLiteOutputTensor)) // dynamic outputs are rejected with a logged error
+    {
+        TF_LITE_MAYBE_KERNEL_LOG(
+            tfLiteContext,
+            "TfLiteArmnnDelegate: Dynamic output tensors are not supported in node #%d: ",
+            nodeIndex);
+        return kTfLiteError;
+    }
+
+    const armnn::TensorInfo& inputTensorInfo  = GetTensorInfoForTfLiteTensor(tfLiteInputTensor); // Arm NN tensor info for the input
+    const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor); // Arm NN tensor info for the output
+
+
+    if (!delegateData.m_Network) // no network set: only run the support check and return, no layer is added
+    {
+        switch(softmaxOperatorCode)
+        {
+            case kTfLiteBuiltinSoftmax:
+            {
+                armnn::SoftmaxDescriptor descriptor;
+                auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(tfLiteNode->builtin_data); // NOTE(review): dereferenced below without a null check - confirm builtin_data is always set for this op
+                descriptor.m_Beta = params->beta; // carry the TfLite beta over to the Arm NN descriptor
+                return ValidateSoftmaxOperator(delegateData,
+                                               tfLiteContext,
+                                               inputTensorInfo,
+                                               outputTensorInfo,
+                                               descriptor);
+            }
+            case kTfLiteBuiltinLogSoftmax:
+            {
+                armnn::LogSoftmaxDescriptor descriptor; // left default-constructed; no fields are set here
+                return ValidateLogSoftmaxOperator(delegateData,
+                                                  tfLiteContext,
+                                                  inputTensorInfo,
+                                                  outputTensorInfo,
+                                                  descriptor);
+            }
+            default:
+                return kTfLiteError; // any other operator code is not handled by this visitor
+        }
+    }
+
+    armnn::IConnectableLayer* softmaxLayer = nullptr; // assigned by whichever case below adds the layer
+
+    switch(softmaxOperatorCode)
+    {
+        case kTfLiteBuiltinSoftmax:
+        {
+            armnn::SoftmaxDescriptor descriptor;
+            auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(tfLiteNode->builtin_data); // NOTE(review): dereferenced below without a null check - confirm builtin_data is always set for this op
+            descriptor.m_Beta = params->beta; // carry the TfLite beta over to the Arm NN descriptor
+            softmaxLayer = delegateData.m_Network->AddSoftmaxLayer(descriptor);
+            break;
+        }
+        case kTfLiteBuiltinLogSoftmax:
+        {
+            armnn::LogSoftmaxDescriptor descriptor; // left default-constructed; no fields are set here
+            softmaxLayer = delegateData.m_Network->AddLogSoftmaxLayer(descriptor);
+            break;
+        }
+        default:
+            return kTfLiteError; // any other operator code is not handled by this visitor
+    }
+    ARMNN_ASSERT(softmaxLayer != nullptr); // both handled cases assign the layer; assert the Add*Layer call returned one
+
+    armnn::IOutputSlot& outputSlot = softmaxLayer->GetOutputSlot(0); // the layer's first output slot
+    outputSlot.SetTensorInfo(outputTensorInfo); // output slot takes the info derived from the TfLite output tensor
+
+    // Hand the new layer to the Connect helper (defined elsewhere) and return its status
+    return Connect(softmaxLayer, tfLiteNode, delegateData);
 }
 
 } // namespace armnnDelegate