IVGCVSW-3236 Extend Ref LSTM with layer normalization support

* Add descriptor values
* Update lstm queue descriptor validate function
* Update lstm workload
* Update isLstmSupported (Cl and Ref), LayerSupportBase, ILayerSupport
* Update lstm layer
* Add unit tests

Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Change-Id: I932175d550facfb342325051eaa7bd2084ebdc18
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
index 7c7af2d..c696098 100644
--- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
+++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
@@ -453,22 +453,139 @@
 
 BOOST_AUTO_TEST_CASE(LstmQueueDescriptor_Validate)
 {
-    armnn::TensorInfo inputTensorInfo;
-    armnn::TensorInfo outputTensorInfo;
+    armnn::DataType dataType = armnn::DataType::Float32;
 
-    unsigned int inputShape[] = { 1, 2 };
-    unsigned int outputShape[] = { 1 };
+    float qScale = 0.0f;
+    int32_t qOffset = 0;
 
-    inputTensorInfo = armnn::TensorInfo(2, inputShape, armnn::DataType::Float32);
-    outputTensorInfo = armnn::TensorInfo(1, outputShape, armnn::DataType::Float32);
+    unsigned int batchSize = 2;
+    unsigned int outputSize = 3;
+    unsigned int inputSize = 5;
+    unsigned numUnits = 4;
 
-    LstmQueueDescriptor invalidData;
-    WorkloadInfo        invalidInfo;
+    armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, dataType,  qScale, qOffset );
+    armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, dataType, qScale, qOffset);
+    armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, dataType, qScale, qOffset);
 
-    AddInputToWorkload(invalidData, invalidInfo, inputTensorInfo, nullptr);
-    AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr);
+    // Scratch buffer size without CIFG [batchSize, numUnits * 4]
+    armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, dataType, qScale, qOffset);
+    armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, dataType, qScale, qOffset);
+    armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, dataType, qScale, qOffset);
+    armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, dataType, qScale, qOffset);
 
-    BOOST_CHECK_THROW(invalidData.Validate(invalidInfo), armnn::InvalidArgumentException);
+    armnn::TensorInfo tensorInfo3({outputSize}, dataType, qScale, qOffset);
+    armnn::TensorInfo tensorInfo4({numUnits}, dataType, qScale, qOffset);
+    armnn::TensorInfo tensorInfo4x5({numUnits, inputSize}, dataType, qScale, qOffset);
+    armnn::TensorInfo tensorInfo4x3({numUnits, outputSize}, dataType, qScale, qOffset);
+    armnn::TensorInfo tensorInfo3x4({outputSize, numUnits}, dataType, qScale, qOffset);
+
+    LstmQueueDescriptor data;
+    WorkloadInfo        info;
+
+    AddInputToWorkload(data, info, inputTensorInfo, nullptr);
+    AddInputToWorkload(data, info, outputStateInTensorInfo, nullptr);
+    AddInputToWorkload(data, info, cellStateInTensorInfo, nullptr);
+
+    AddOutputToWorkload(data, info, scratchBufferTensorInfo, nullptr);
+    AddOutputToWorkload(data, info, outputStateOutTensorInfo, nullptr);
+    AddOutputToWorkload(data, info, cellStateOutTensorInfo, nullptr);
+    // AddOutputToWorkload(data, info, outputTensorInfo, nullptr) is deliberately left out; it is added later
+
+    armnn::ScopedCpuTensorHandle inputToInputWeightsTensor(tensorInfo4x5);
+    armnn::ScopedCpuTensorHandle inputToForgetWeightsTensor(tensorInfo4x5);
+    armnn::ScopedCpuTensorHandle inputToCellWeightsTensor(tensorInfo4x5);
+    armnn::ScopedCpuTensorHandle inputToOutputWeightsTensor(tensorInfo4x5);
+    armnn::ScopedCpuTensorHandle recurrentToForgetWeightsTensor(tensorInfo4x3);
+    armnn::ScopedCpuTensorHandle recurrentToInputWeightsTensor(tensorInfo4x3);
+    armnn::ScopedCpuTensorHandle recurrentToCellWeightsTensor(tensorInfo4x3);
+    armnn::ScopedCpuTensorHandle recurrentToOutputWeightsTensor(tensorInfo4x3);
+    armnn::ScopedCpuTensorHandle cellToInputWeightsTensor(tensorInfo4);
+    armnn::ScopedCpuTensorHandle inputGateBiasTensor(tensorInfo4);
+    armnn::ScopedCpuTensorHandle forgetGateBiasTensor(tensorInfo4);
+    armnn::ScopedCpuTensorHandle cellBiasTensor(tensorInfo4);
+    armnn::ScopedCpuTensorHandle outputGateBiasTensor(tensorInfo4);
+    armnn::ScopedCpuTensorHandle cellToForgetWeightsTensor(tensorInfo4);
+    armnn::ScopedCpuTensorHandle cellToOutputWeightsTensor(tensorInfo4);
+    armnn::ScopedCpuTensorHandle projectionWeightsTensor(tensorInfo3x4);
+    armnn::ScopedCpuTensorHandle projectionBiasTensor(tensorInfo3);
+    armnn::ScopedCpuTensorHandle inputLayerNormWeightsTensor(tensorInfo4);
+    armnn::ScopedCpuTensorHandle forgetLayerNormWeightsTensor(tensorInfo4);
+    armnn::ScopedCpuTensorHandle cellLayerNormWeightsTensor(tensorInfo4);
+    armnn::ScopedCpuTensorHandle outputLayerNormWeightsTensor(tensorInfo4);
+
+    data.m_InputToInputWeights = &inputToInputWeightsTensor;
+    data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
+    data.m_InputToCellWeights = &inputToCellWeightsTensor;
+    data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
+    data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
+    data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
+    data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
+    data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
+    data.m_CellToInputWeights = &cellToInputWeightsTensor;
+    data.m_InputGateBias = &inputGateBiasTensor;
+    data.m_ForgetGateBias = &forgetGateBiasTensor;
+    data.m_CellBias = &cellBiasTensor;
+    data.m_OutputGateBias = &outputGateBiasTensor;
+    data.m_CellToForgetWeights = &cellToForgetWeightsTensor;
+    data.m_CellToOutputWeights = &cellToOutputWeightsTensor;
+    data.m_ProjectionWeights = &projectionWeightsTensor;
+    data.m_ProjectionBias = &projectionBiasTensor;
+
+    data.m_InputLayerNormWeights = &inputLayerNormWeightsTensor;
+    data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor;
+    data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor;
+    data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor;
+
+    // Flags to set test configuration
+    data.m_Parameters.m_ActivationFunc = 4;
+    data.m_Parameters.m_CifgEnabled = false;
+    data.m_Parameters.m_PeepholeEnabled = true;
+    data.m_Parameters.m_ProjectionEnabled = true;
+    data.m_Parameters.m_LayerNormEnabled = true;
+
+    // check wrong number of outputs
+    BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
+    AddOutputToWorkload(data, info, outputTensorInfo, nullptr);
+
+    // check wrong cifg parameter configuration
+    data.m_Parameters.m_CifgEnabled = true;
+    armnn::TensorInfo scratchBufferTensorInfo2({batchSize, numUnits * 3}, dataType, qScale, qOffset);
+    SetWorkloadOutput(data, info, 0, scratchBufferTensorInfo2, nullptr);
+    BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
+    data.m_Parameters.m_CifgEnabled = false;
+    SetWorkloadOutput(data, info, 0, scratchBufferTensorInfo, nullptr);
+
+    // check wrong inputGateBias configuration
+    data.m_InputGateBias = nullptr;
+    BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
+    data.m_InputGateBias = &inputGateBiasTensor;
+
+    // check inconsistent projection parameters
+    data.m_Parameters.m_ProjectionEnabled = false;
+    BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
+    data.m_Parameters.m_ProjectionEnabled = true;
+    data.m_ProjectionWeights = nullptr;
+    BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
+    data.m_ProjectionWeights = &projectionWeightsTensor;
+
+    // check missing input layer normalisation weights
+    data.m_InputLayerNormWeights = nullptr;
+    BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
+    data.m_InputLayerNormWeights = &inputLayerNormWeightsTensor;
+
+    // layer norm disabled but normalisation weights are present
+    data.m_Parameters.m_LayerNormEnabled = false;
+    BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
+    data.m_Parameters.m_LayerNormEnabled = true;
+
+    // check invalid outputTensor shape
+    armnn::TensorInfo incorrectOutputTensorInfo({batchSize, outputSize + 1}, dataType, qScale, qOffset);
+    SetWorkloadOutput(data, info, 3, incorrectOutputTensorInfo, nullptr);
+    BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
+    SetWorkloadOutput(data, info, 3, outputTensorInfo, nullptr);
+
+    // check correct configuration
+    BOOST_CHECK_NO_THROW(data.Validate(info));
 }
 
 BOOST_AUTO_TEST_SUITE_END()