IVGCVSW-5274 'Update ConvertQuantizedLstm function to use ShapeInferenceMethod'

* Enabled Dynamic Tensors in QUANTIZED_LSTM operator.

!android-nn-driver:3897

Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: I415014d19729aac255479099e372e5ff1a6dd3e2
diff --git a/ConversionUtils_1_3.hpp b/ConversionUtils_1_3.hpp
index e696125..445b9ea 100644
--- a/ConversionUtils_1_3.hpp
+++ b/ConversionUtils_1_3.hpp
@@ -600,29 +600,36 @@
     }
 
     // Check if the layer is supported
-
-    if (IsDynamicTensor(constOutputStateOutInfo) ||
-        IsDynamicTensor(cellStateOutInfo)   ||
-        IsDynamicTensor(constOutputInfo))
+    bool isSupported = false;
+    auto validateFunc = [&](const armnn::TensorInfo& cellStateOutInfo, bool& isSupported)
     {
-        return Fail("%s: Dynamic output tensors are not supported %d %d %d %d", __func__,
-                    IsDynamicTensor(constOutputStateOutInfo), IsDynamicTensor(cellStateOutInfo),
-                    IsDynamicTensor(constOutputInfo));
+        FORWARD_LAYER_SUPPORT_FUNC(__func__,
+                                   IsQLstmSupported,
+                                   data.m_Backends,
+                                   isSupported,
+                                   inputInfo,
+                                   outputStatePrevTimeStepInfo,
+                                   cellStatePrevTimeStepInfo,
+                                   constOutputStateOutInfo,
+                                   cellStateOutInfo,
+                                   constOutputInfo,
+                                   desc,
+                                   paramsInfo);
+    };
+
+    bool isDynamic = false;
+    if (!IsDynamicTensor(constOutputStateOutInfo) &&
+        !IsDynamicTensor(cellStateOutInfo)  &&
+        !IsDynamicTensor(constOutputInfo))
+    {
+        validateFunc(outputInfo, isSupported);
+    }
+    else
+    {
+        isDynamic = true;
+        isSupported = AreDynamicTensorsSupported();
     }
 
-    bool isSupported = false;
-    FORWARD_LAYER_SUPPORT_FUNC(__func__,
-                               IsQLstmSupported,
-                               data.m_Backends,
-                               isSupported,
-                               inputInfo,
-                               outputStatePrevTimeStepInfo,
-                               cellStatePrevTimeStepInfo,
-                               constOutputStateOutInfo,
-                               cellStateOutInfo,
-                               constOutputInfo,
-                               desc,
-                               paramsInfo);
     if (!isSupported)
     {
         return false;
@@ -635,10 +642,21 @@
     outputStatePrevTimeStep.Connect(layer->GetInputSlot(1));
     cellStatePrevTimeStep.Connect(layer->GetInputSlot(2));
 
-    return ( SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 0, model, data,
-                                                     &constOutputStateOutInfo) &&
-             SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data) &&
-             SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data, &constOutputInfo));
+    if (!isDynamic)
+    {
+        return ( SetupAndTrackLayerOutputSlot<HalPolicy>(
+                       operation, 0, *layer, 0, model, data, &constOutputStateOutInfo) &&
+                 SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data) &&
+                 SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data, &constOutputInfo));
+    }
+    else
+    {
+        return ( SetupAndTrackLayerOutputSlot<HalPolicy>(
+                       operation, 0, *layer, 0, model, data, &constOutputStateOutInfo) &&
+                 SetupAndTrackLayerOutputSlot<HalPolicy>(
+                       operation, 1, *layer, 1, model, data, nullptr, validateFunc, true) &&
+                 SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data, &constOutputInfo));
+    }
 }
 
 template<typename HalPolicy,
diff --git a/test/1.3/QLstm.cpp b/test/1.3/QLstm.cpp
index 608e408..f58ed0a 100644
--- a/test/1.3/QLstm.cpp
+++ b/test/1.3/QLstm.cpp
@@ -863,6 +863,167 @@
                   compute);
 }
 
+void DynamicOutputQLstmWithNoProjection(armnn::Compute compute)
+{
+    // This replicates android/frameworks/ml/nn/runtime/test/specs/V1_3/qlstm_noprojection.mod.py
+    // with values from android/frameworks/ml/nn/runtime/test/generated/spec_V1_3/qlstm_noprojection.example.cpp
+    // and weights, biases and scalars passed as CONSTANT_COPY tensors (instead of SUBGRAPH_INPUT tensors)
+    // and made cellStateOutput dynamic.
+
+    uint32_t batchSize  = 2;
+    uint32_t inputSize  = 5;
+    uint32_t outputSize = 4;
+    uint32_t numUnits   = 4;
+
+    // Inputs:
+    hidl_vec<uint32_t> inputDimensions{batchSize, inputSize};
+    std::vector<int8_t> inputValue { 90, 102, 13, 26, 38, 102, 13, 26, 51, 64 };
+
+    hidl_vec<uint32_t> inputToInputWeightsDimensions{0, 0};
+    std::vector<int8_t> inputToInputWeightsValue;
+
+    hidl_vec<uint32_t> inputToForgetWeightsDimensions{numUnits, inputSize};
+    std::vector<int8_t> inputToForgetWeightsValue { -77, -13,  38,  25,  115,
+                                                    -64, -25, -51,  38, -102,
+                                                    -51,  38, -64, -51,  -77,
+                                                    38, -51, -77, -64,  -64 };
+
+    hidl_vec<uint32_t> inputToCellWeightsDimensions{numUnits, inputSize};
+    std::vector<int8_t> inputToCellWeightsValue { -51,  -38, -25, -13, -64,
+                                                  64,  -25, -38, -25, -77,
+                                                  77,  -13, -51, -38, -89,
+                                                  89, -115, -64, 102,  77 };
+
+    hidl_vec<uint32_t> inputToOutputWeightsDimensions{numUnits, inputSize};
+    std::vector<int8_t> inputToOutputWeightsValue { -102, -51, -25, -115, -13,
+                                                    -89,  38, -38, -102, -25,
+                                                    77, -25,  51,  -89, -38,
+                                                    -64,  13,  64,  -77, -51 };
+
+    hidl_vec<uint32_t> recurrentToInputWeightsDimensions{0, 0};
+    std::vector<int8_t> recurrentToInputWeightsValue;
+
+    hidl_vec<uint32_t> recurrentToForgetWeightsDimensions{numUnits, outputSize};
+    std::vector<int8_t> recurrentToForgetWeightsValue { -64, -38, -64, -25,
+                                                        77,  51, 115,  38,
+                                                        -13,  25,  64,  25,
+                                                        25,  38, -13,  51 };
+
+    hidl_vec<uint32_t> recurrentToCellWeightsDimensions{numUnits, outputSize};
+    std::vector<int8_t> recurrentToCellWeightsValue { -38,  25,  13, -38,
+                                                      102, -10, -25,  38,
+                                                      102, -77, -13,  25,
+                                                      38, -13,  25,  64 };
+
+    hidl_vec<uint32_t> recurrentToOutputWeightsDimensions{numUnits, outputSize};
+    std::vector<int8_t> recurrentToOutputWeightsValue {  38, -13,  13, -25,
+                                                         -64, -89, -25, -77,
+                                                         -13, -51, -89, -25,
+                                                         13,  64,  25, -38 };
+
+    hidl_vec<uint32_t> cellToInputWeightsDimensions{0};
+    std::vector<int16_t> cellToInputWeightsValue;
+
+    hidl_vec<uint32_t> cellToForgetWeightsDimensions{0};
+    std::vector<int16_t> cellToForgetWeightsValue;
+
+    hidl_vec<uint32_t> cellToOutputWeightsDimensions{0};
+    std::vector<int16_t> cellToOutputWeightsValue;
+
+    hidl_vec<uint32_t> inputGateBiasDimensions{0};
+    std::vector<int32_t> inputGateBiasValue;
+
+    hidl_vec<uint32_t> forgetGateBiasDimensions{numUnits};
+    std::vector<int32_t> forgetGateBiasValue { 2147484, -6442451, -4294968, 2147484 };
+
+    hidl_vec<uint32_t> cellBiasDimensions{numUnits};
+    std::vector<int32_t> cellBiasValue { -1073742, 15461883, 5368709, 1717987 };
+
+    hidl_vec<uint32_t> outputGateBiasDimensions{numUnits};
+    std::vector<int32_t> outputGateBiasValue { 1073742, -214748, 4294968, 2147484 };
+
+    hidl_vec<uint32_t> projectionWeightsDimensions{0, 0};
+    std::vector<int8_t> projectionWeightsValue;
+
+    hidl_vec<uint32_t> projectionBiasDimensions{0};
+    std::vector<int32_t> projectionBiasValue;
+
+    hidl_vec<uint32_t> outputStateInDimensions{batchSize, outputSize};
+    std::vector<int8_t> outputStateInValue { 0, 0, 0, 0, 0, 0, 0, 0 };
+
+    hidl_vec<uint32_t> cellStateInDimensions{batchSize, numUnits};
+    std::vector<int16_t> cellStateInValue { 0, 0, 0, 0, 0, 0, 0, 0 };
+
+    // Normalization:
+    hidl_vec<uint32_t> inputLayerNormWeightsDimensions{0};
+    std::vector<int16_t> inputLayerNormWeightsValue;
+
+    hidl_vec<uint32_t> forgetLayerNormWeightsDimensions{numUnits};
+    std::vector<int16_t> forgetLayerNormWeightsValue { 6553, 6553, 13107, 9830 };
+
+    hidl_vec<uint32_t> cellLayerNormWeightsDimensions{numUnits};
+    std::vector<int16_t> cellLayerNormWeightsValue { 22937, 6553, 9830, 26214 };
+
+    hidl_vec<uint32_t> outputLayerNormWeightsDimensions{numUnits};
+    std::vector<int16_t> outputLayerNormWeightsValue { 19660, 6553, 6553, 16384 };
+
+    float cellClipValue           = 0.0f;
+    float projectionClipValue     = 0.0f;
+    float inputIntermediateScale  = 0.007059f;
+    float forgetIntermediateScale = 0.007812f;
+    float cellIntermediateScale   = 0.007059f;
+    float outputIntermediateScale = 0.007812f;
+    int32_t hiddenStateZeroPoint  = 0;
+    float hiddenStateScale        = 0.007f;
+
+    // Outputs:
+    hidl_vec<uint32_t> outputStateOutDimensions{batchSize, outputSize};
+    std::vector<int8_t> outputStateOutValue { -15, 21, 14, 20, -15, 15, 5, 27 };
+
+    hidl_vec<uint32_t> cellStateOutDimensions{};
+    std::vector<int16_t> cellStateOutValue { -11692, 9960, 5491, 8861, -9422, 7726, 2056, 13149 };
+
+    hidl_vec<uint32_t> outputDimensions{batchSize, outputSize};
+    std::vector<int8_t> outputValue { -15, 21, 14, 20, -15, 15, 5, 27 };
+
+    QLstmTestImpl(inputDimensions,                       inputValue,
+                  inputToInputWeightsDimensions,         inputToInputWeightsValue,
+                  inputToForgetWeightsDimensions,        inputToForgetWeightsValue,
+                  inputToCellWeightsDimensions,          inputToCellWeightsValue,
+                  inputToOutputWeightsDimensions,        inputToOutputWeightsValue,
+                  recurrentToInputWeightsDimensions,     recurrentToInputWeightsValue,
+                  recurrentToForgetWeightsDimensions,    recurrentToForgetWeightsValue,
+                  recurrentToCellWeightsDimensions,      recurrentToCellWeightsValue,
+                  recurrentToOutputWeightsDimensions,    recurrentToOutputWeightsValue,
+                  cellToInputWeightsDimensions,          cellToInputWeightsValue,
+                  cellToForgetWeightsDimensions,         cellToForgetWeightsValue,
+                  cellToOutputWeightsDimensions,         cellToOutputWeightsValue,
+                  inputGateBiasDimensions,               inputGateBiasValue,
+                  forgetGateBiasDimensions,              forgetGateBiasValue,
+                  cellBiasDimensions,                    cellBiasValue,
+                  outputGateBiasDimensions,              outputGateBiasValue,
+                  projectionWeightsDimensions,           projectionWeightsValue,
+                  projectionBiasDimensions,              projectionBiasValue,
+                  outputStateInDimensions,               outputStateInValue,
+                  cellStateInDimensions,                 cellStateInValue,
+                  inputLayerNormWeightsDimensions,       inputLayerNormWeightsValue,
+                  forgetLayerNormWeightsDimensions,      forgetLayerNormWeightsValue,
+                  cellLayerNormWeightsDimensions,        cellLayerNormWeightsValue,
+                  outputLayerNormWeightsDimensions,      outputLayerNormWeightsValue,
+                  cellClipValue,
+                  projectionClipValue,
+                  inputIntermediateScale,
+                  forgetIntermediateScale,
+                  cellIntermediateScale,
+                  outputIntermediateScale,
+                  hiddenStateZeroPoint,
+                  hiddenStateScale,
+                  outputStateOutDimensions,              outputStateOutValue,
+                  cellStateOutDimensions,                cellStateOutValue,
+                  outputDimensions,                      outputValue,
+                  compute);
+}
+
 } // anonymous namespace
 
 // Support is not added yet
@@ -876,4 +1037,9 @@
     QLstmWithNoProjection(sample);
 }
 
+BOOST_DATA_TEST_CASE(DynamicOutputQLSTMWithNoProjectionTest, COMPUTE_DEVICES)
+{
+    DynamicOutputQLstmWithNoProjection(sample);
+}
+
 BOOST_AUTO_TEST_SUITE_END()
\ No newline at end of file