IVGCVSW-3456 Add support for dynamic output shape in ConvertPrelu

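 * Added OutputShapeUtils with a helper to infer the PReLU output shape
   by broadcasting the input and alpha shapes against each other
 * Made ConvertPrelu infer the output shape from its inputs when the
   output operand has a dynamic shape (i.e. zero elements)
 * Added an optional TensorInfo parameter to SetupAndTrackLayerOutputSlot
   that, when set, overrides the tensor info of the output operand
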
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: I8fc7a716455be3f51b51177f6896a73790a41fc3
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index fe0cfbd..b194a57 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -5,6 +5,8 @@
 
 #include "HalPolicy.hpp"
 
+#include "OutputShapeUtils.hpp"
+
 #include "../1.0/HalPolicy.hpp"
 #include "../1.1/HalPolicy.hpp"
 
@@ -539,7 +541,15 @@
 
     const armnn::TensorInfo& inputInfo  = input.GetTensorInfo();
     const armnn::TensorInfo& alphaInfo  = alpha.GetTensorInfo();
-    const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
+
+    armnn::TensorInfo outputInfo = GetTensorInfoForOperand(*output);
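+    // A tensor info with zero elements signals a dynamic (not fully
+    // specified) output shape, which must be inferred from the inputs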
+    if (outputInfo.GetNumElements() == 0u)
+    {
+        ALOGD("Output shape not set, will infer from inputs");
+        outputInfo.SetShape(InferPreluOutputShape(inputInfo.GetShape(), alphaInfo.GetShape()));
+    }
 
     if (!IsLayerSupportedForAnyBackend(__func__,
                                        armnn::IsPreluSupported,
@@ -560,7 +570,12 @@
 
     BroadcastTensor(input, alpha, layer, *data.m_Network);
 
-    return SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 0, *layer, model, data);
+    return SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation,
+                                                            0,
+                                                            *layer,
+                                                            model,
+                                                            data,
+                                                            armnn::Optional<armnn::TensorInfo>(outputInfo));
 }
 
 bool HalPolicy::ConvertResize(const Operation& operation,
diff --git a/Android.mk b/Android.mk
index bee57dd..215b0a8 100644
--- a/Android.mk
+++ b/Android.mk
@@ -334,6 +334,7 @@
         ConversionUtils.cpp \
         DriverOptions.cpp \
         ModelToINetworkConverter.cpp \
+        OutputShapeUtils.cpp \
         RequestThread.cpp \
         Utils.cpp
 
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index d30b8a4..c9be000 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -1028,7 +1028,8 @@
                                   armnn::IConnectableLayer& layer,
                                   uint32_t layerOutputIndex,
                                   const HalModel& model,
-                                  ConversionData& data)
+                                  ConversionData& data,
+                                  const armnn::Optional<armnn::TensorInfo>& outputInfo = armnn::EmptyOptional())
 {
     using HalOperand = typename HalPolicy::Operand;
 
@@ -1043,7 +1044,17 @@
     const uint32_t operandIndex = operation.outputs[operationOutputIndex];
     data.m_OutputSlotForOperand[operandIndex] = &outputSlot;
 
-    outputSlot.SetTensorInfo(GetTensorInfoForOperand(*outputOperand));
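+    // Use the explicitly supplied tensor info if there is one; otherwise
+    // fall back to the tensor info of the output operand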
+    if (outputInfo.has_value())
+    {
+        outputSlot.SetTensorInfo(outputInfo.value());
+        ALOGD("Output info overwritten");
+    }
+    else
+    {
+        outputSlot.SetTensorInfo(GetTensorInfoForOperand(*outputOperand));
+    }
 
     return true;
 }
@@ -1092,9 +1103,16 @@
                                   uint32_t outputIndex,
                                   armnn::IConnectableLayer& layer,
                                   const HalModel& model,
-                                  ConversionData& data)
+                                  ConversionData& data,
+                                  const armnn::Optional<armnn::TensorInfo>& outputInfo = armnn::EmptyOptional())
 {
-    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, outputIndex, layer, outputIndex, model, data);
+    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation,
+                                                   outputIndex,
+                                                   layer,
+                                                   outputIndex,
+                                                   model,
+                                                   data,
+                                                   outputInfo);
 }
 
 template<typename HalPolicy,
diff --git a/OutputShapeUtils.cpp b/OutputShapeUtils.cpp
new file mode 100644
index 0000000..de27630
--- /dev/null
+++ b/OutputShapeUtils.cpp
@@ -0,0 +1,45 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "OutputShapeUtils.hpp"
+
+#include <algorithm>
+
+namespace armnn_driver
+{
+
+using namespace armnn;
+
+TensorShape InferPreluOutputShape(const TensorShape& inputShape, const TensorShape& alphaShape)
+{
+    // NOTE: The inferred PReLU output size will be the maximum size along each dimension
+    // of input and alpha, starting with the trailing dimensions, and working its way forward.
+    //
+    // Example: inputShape={4, 1, 2}, alphaShape={5, 4, 3, 1} => outputShape={5, 4, 3, 2}
+
+    const unsigned int numInputDims = inputShape.GetNumDimensions();
+    const unsigned int numAlphaDims = alphaShape.GetNumDimensions();
+
+    const unsigned int maxNumDims = std::max(numInputDims, numAlphaDims);
+
+    TensorShape outputShape = TensorShape(maxNumDims);
+    for (unsigned int reverseIdx = 1u; reverseIdx <= maxNumDims; ++reverseIdx)
+    {
+        const int inputIdx = static_cast<int>(numInputDims) - static_cast<int>(reverseIdx);
+        const int alphaIdx = static_cast<int>(numAlphaDims) - static_cast<int>(reverseIdx);
+
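+        // A negative index means this tensor has no dimension at this
+        // position; treat it as size 0 so std::max picks the other size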
+        const unsigned int inputDimSize = inputIdx >= 0 ? inputShape[inputIdx] : 0u;
+        const unsigned int alphaDimSize = alphaIdx >= 0 ? alphaShape[alphaIdx] : 0u;
+
+        const unsigned int outputIdx = maxNumDims - reverseIdx;
+        outputShape[outputIdx] = std::max(inputDimSize, alphaDimSize);
+    }
+
+    return outputShape;
+}
+
+} // namespace armnn_driver
diff --git a/OutputShapeUtils.hpp b/OutputShapeUtils.hpp
new file mode 100644
index 0000000..f314252
--- /dev/null
+++ b/OutputShapeUtils.hpp
@@ -0,0 +1,15 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/ArmNN.hpp>
+
+namespace armnn_driver
+{
+
+armnn::TensorShape InferPreluOutputShape(const armnn::TensorShape& inputShape, const armnn::TensorShape& alphaShape);
+
+} // namespace armnn_driver