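MLCE-133: Driver infers LSTM scratch buffer shape when it is unspecified.

When a model leaves the LSTM scratch buffer shape unspecified (e.g. (0,0)),
the 1.2 HAL policy now derives it as [batch_size, num_units * 3] with CIFG,
or [batch_size, num_units * 4] without CIFG, instead of failing the
dynamic-tensor check for that output.
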
Change-Id: I7977d697772349b8ea7eb300937409ce0a3a4dee
Signed-off-by: Pablo Tello <pablo.tello@arm.com>
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index 60bbf1d..ac78e96 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -2101,6 +2101,54 @@
     return ::ConvertTanH<hal_1_2::HalPolicy>(operation, model, data);
 }
 
+/// Records layer output slot `layerOutputIndex` as the producer of operand
+/// `operation.outputs[operationOutputIndex]` and sets that slot's TensorInfo to the
+/// caller-supplied `tensorInfo` rather than reading it from the model operand.
+/// This lets the driver publish a shape the model left unspecified, e.g. an LSTM
+/// scratch buffer declared with shape (0,0).
+///
+/// @param operation            Operation whose output is being tracked.
+/// @param operationOutputIndex Index into operation.outputs.
+/// @param layer                Layer that produces the output.
+/// @param layerOutputIndex     Index of the layer output slot to track.
+/// @param model                Model used to look up the output operand.
+/// @param data                 Conversion state; receives the operand -> output slot mapping.
+/// @param tensorInfo           TensorInfo to set on the layer output slot.
+/// @return false if the output operand cannot be found or either index is out of range.
+template<typename HalPolicy,
+         typename HalOperation = typename HalPolicy::Operation,
+         typename HalModel     = typename HalPolicy::Model>
+bool SetupAndTrackLayerOutputSlotAndOverrideTensorInfo(const HalOperation& operation,
+                                                       uint32_t operationOutputIndex,
+                                                       armnn::IConnectableLayer& layer,
+                                                       uint32_t layerOutputIndex,
+                                                       const HalModel& model,
+                                                       ConversionData& data,
+                                                       const armnn::TensorInfo& tensorInfo)
+{
+    using HalOperand = typename HalPolicy::Operand;
+
+    const HalOperand* outputOperand = GetOutputOperand<HalPolicy>(operation, operationOutputIndex, model);
+    // Bounds-check each index against the container it actually addresses; in particular
+    // layerOutputIndex (not operationOutputIndex) is what selects the layer slot below.
+    if ((outputOperand == nullptr) ||
+        (operationOutputIndex >= operation.outputs.size()) ||
+        (layerOutputIndex >= layer.GetNumOutputSlots()))
+    {
+        return false;
+    }
+
+    armnn::IOutputSlot& outputSlot = layer.GetOutputSlot(layerOutputIndex);
+
+    const uint32_t operandIndex = operation.outputs[operationOutputIndex];
+    data.m_OutputSlotForOperand[operandIndex] = &outputSlot;
+
+    // The operand's own TensorInfo is deliberately not used: it may be unspecified.
+    outputSlot.SetTensorInfo(tensorInfo);
+
+    return true;
+}
+
 bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, ConversionData& data)
 {
     ALOGV("hal_1_2::HalPolicy::ConvertLstm()");
@@ -2399,8 +2429,28 @@
     const TensorInfo& cellStateOutInfo   = GetTensorInfoForOperand(*cellStateOut);
     const TensorInfo& outputInfo         = GetTensorInfoForOperand(*output);
 
-    if (IsDynamicTensor(scratchBufferInfo)  ||
-        IsDynamicTensor(outputStateOutInfo) ||
+    // The model may leave the scratch buffer shape unspecified (e.g. (0,0)), in which
+    // case the driver infers it from the other LSTM tensors:
+    //   scratch_buffer: [batch_size, num_units * 3] with CIFG,
+    //                   [batch_size, num_units * 4] without CIFG,
+    // where batch_size = output_state_in[0] and num_units = cell_state_out[1].
+    TensorInfo fixSbInfo = scratchBufferInfo;
+    if (IsDynamicTensor(scratchBufferInfo) &&
+        outputStateInInfo.GetNumDimensions() == 2 &&
+        cellStateOutInfo.GetNumDimensions() == 2)
+    {
+        const unsigned int batchSize = outputStateInInfo.GetShape()[0];
+        const unsigned int numUnits  = cellStateOutInfo.GetShape()[1];
+        const unsigned int numGates  = desc.m_CifgEnabled ? 3 : 4;
+        // Build a fresh 2-D shape instead of writing into the existing one, which may
+        // not be 2-D when the operand is dynamic.
+        fixSbInfo.SetShape(armnn::TensorShape({ batchSize, numUnits * numGates }));
+    }
+
+    // If the scratch buffer shape could not be inferred it is still dynamic and is
+    // handled by the same check as the other outputs.
+    if (IsDynamicTensor(fixSbInfo)          ||
+        IsDynamicTensor(outputStateOutInfo) ||
         IsDynamicTensor(cellStateOutInfo)   ||
         IsDynamicTensor(outputInfo))
     {
@@ -2467,7 +2517,7 @@
                                inputInfo,
                                outputStateInInfo,
                                cellStateInInfo,
-                               scratchBufferInfo,
+                               fixSbInfo, // scratch buffer info, shape inferred if the model left it unspecified
                                outputStateOutInfo,
                                cellStateOutInfo,
                                outputInfo,
@@ -2485,7 +2535,13 @@
     outputStateIn.Connect(layer->GetInputSlot(1));
     cellStateIn.Connect(layer->GetInputSlot(2));
 
-    return (SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 0, *layer, 0, model, data) &&
+    // Output 0 is the scratch buffer. When the model left its shape unspecified,
+    // publish the driver-inferred TensorInfo (fixSbInfo) instead of the model's.
+    const bool scratchBufferTracked = IsDynamicTensor(scratchBufferInfo)
+        ? SetupAndTrackLayerOutputSlotAndOverrideTensorInfo<hal_1_2::HalPolicy>(
+              operation, 0, *layer, 0, model, data, fixSbInfo)
+        : SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 0, *layer, 0, model, data);
+    return (scratchBufferTracked &&
             SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 1, *layer, 1, model, data) &&
             SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 2, *layer, 2, model, data) &&
             SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 3, *layer, 3, model, data));