IVGCVSW-3236 Extend Ref LSTM with layer normalization support

* Add the layer normalization descriptor values
* Update the LstmQueueDescriptor Validate() function
* Update RefLstmWorkload (the per-gate normalization is sketched below)
* Update IsLstmSupported (Cl and Ref), LayerSupportBase and ILayerSupport
* Update LstmLayer
* Add unit tests
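
For reviewers, a minimal sketch of the per-gate layer normalization the Ref
workload now performs: normalize the gate pre-activation per batch row,
scale by the per-cell layer norm weights, then add the gate bias before the
gate activation. The helper name LayerNormGate and the flat float buffers
are illustrative only; the real implementation goes through the
Encoder/Decoder based helpers in LstmUtils (MeanStddevNormalization,
VectorBatchVectorCwiseProduct, VectorBatchVectorAdd).

    #include <cmath>

    // Sketch only -- hypothetical helper operating on plain float buffers
    // instead of armnn::Encoder/Decoder. Per gate it mirrors
    // MeanStddevNormalization + VectorBatchVectorCwiseProduct +
    // VectorBatchVectorAdd.
    void LayerNormGate(const float* gate,        // [nBatch * nCell] pre-activation
                       const float* normWeights, // [nCell] layer norm weights
                       const float* bias,        // [nCell] gate bias
                       float* out,               // [nBatch * nCell] result
                       unsigned int nBatch,
                       unsigned int nCell,
                       float epsilon = 1e-8f)    // same default as m_LayerNormEpsilon
    {
        for (unsigned int b = 0; b < nBatch; ++b)
        {
            const float* in = gate + b * nCell;
            float sum   = 0.0f;
            float sumSq = 0.0f;
            for (unsigned int i = 0; i < nCell; ++i)
            {
                sum   += in[i];
                sumSq += in[i] * in[i];
            }
            const float mean     = sum / static_cast<float>(nCell);
            const float variance = sumSq / static_cast<float>(nCell) - mean * mean;
            // Same zero-variance fallback as MeanStddevNormalization.
            const float stddevInv = 1.0f / std::sqrt(variance == 0.0f ? epsilon : variance);
            for (unsigned int i = 0; i < nCell; ++i)
            {
                out[b * nCell + i] = (in[i] - mean) * stddevInv * normWeights[i] + bias[i];
            }
        }
    }

Note that with layer normalization enabled the gate scratch buffers are
zero-initialised instead of bias-initialised, because the bias is now added
after normalization (see RefLstmWorkload.cpp below).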

Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Change-Id: I932175d550facfb342325051eaa7bd2084ebdc18
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index b563bad..3d260c5 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -861,7 +861,11 @@
                                       const TensorInfo* projectionBias,
                                       const TensorInfo* cellToForgetWeights,
                                       const TensorInfo* cellToOutputWeights,
-                                      Optional<std::string&> reasonIfUnsupported) const
+                                      Optional<std::string&> reasonIfUnsupported,
+                                      const TensorInfo* inputLayerNormWeights,
+                                      const TensorInfo* forgetLayerNormWeights,
+                                      const TensorInfo* cellLayerNormWeights,
+                                      const TensorInfo* outputLayerNormWeights) const
 {
     ignore_unused(descriptor);
     ignore_unused(inputToForgetWeights);
@@ -881,6 +885,10 @@
     ignore_unused(projectionBias);
     ignore_unused(cellToForgetWeights);
     ignore_unused(cellToOutputWeights);
+    ignore_unused(inputLayerNormWeights);
+    ignore_unused(forgetLayerNormWeights);
+    ignore_unused(cellLayerNormWeights);
+    ignore_unused(outputLayerNormWeights);
 
     bool supported = true;
 
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index 22b007b..ead4d1c 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -155,7 +155,11 @@
                          const TensorInfo* projectionBias,
                          const TensorInfo* cellToForgetWeights,
                          const TensorInfo* cellToOutputWeights,
-                         Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+                         Optional<std::string&> reasonIfUnsupported = EmptyOptional(),
+                         const TensorInfo* inputLayerNormWeights = nullptr,
+                         const TensorInfo* forgetLayerNormWeights = nullptr,
+                         const TensorInfo* cellLayerNormWeights = nullptr,
+                         const TensorInfo* outputLayerNormWeights = nullptr) const override;
 
     bool IsMaximumSupported(const TensorInfo& input0,
                             const TensorInfo& input1,
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 12e5774..a736a88 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -22,6 +22,7 @@
         workloads/ElementwiseFunction.cpp \
         workloads/FullyConnected.cpp \
         workloads/Gather.cpp \
+        workloads/LstmUtils.cpp \
         workloads/Mean.cpp \
         workloads/Concatenate.cpp \
         workloads/Pad.cpp \
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 7797f17..9f89c8c 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -827,6 +827,17 @@
 ARMNN_AUTO_TEST_CASE(PermuteFloat32ValueSet3, PermuteFloat32ValueSet3Test)
 
 // Lstm
+BOOST_AUTO_TEST_CASE(LstmUtilsZeroVector) {
+                     LstmUtilsZeroVectorTest(); }
+BOOST_AUTO_TEST_CASE(LstmUtilsMeanStddevNormalization) {
+                     LstmUtilsMeanStddevNormalizationNoneZeroInputTest();
+                     LstmUtilsMeanStddevNormalizationAllZeroInputTest();
+                     LstmUtilsMeanStddevNormalizationMixedZeroInputTest(); }
+BOOST_AUTO_TEST_CASE(LstmUtilsVectorBatchVectorCwiseProduct) {
+                     LstmUtilsVectorBatchVectorCwiseProductTest(); }
+BOOST_AUTO_TEST_CASE(LstmUtilsVectorBatchVectorAdd) {
+                     LstmUtilsVectorBatchVectorAddTest(); }
+
 ARMNN_AUTO_TEST_CASE(LstmLayerFloat32WithCifgWithPeepholeNoProjection,
                      LstmLayerFloat32WithCifgWithPeepholeNoProjectionTest)
 ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgNoPeepholeNoProjection,
@@ -834,6 +845,9 @@
 ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjection,
                      LstmLayerFloat32NoCifgWithPeepholeWithProjectionTest)
 
+ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNorm,
+                     LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNormTest)
+
 ARMNN_AUTO_TEST_CASE(LstmLayerInt16NoCifgNoPeepholeNoProjection,
                      LstmLayerInt16NoCifgNoPeepholeNoProjectionTest)
 ARMNN_AUTO_TEST_CASE(LstmLayerInt16WithCifgWithPeepholeNoProjection,
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 3c0af01..696605d 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -28,6 +28,7 @@
     Gather.cpp
     Gather.hpp
     LstmUtils.hpp
+    LstmUtils.cpp
     Maximum.hpp
     Mean.cpp
     Mean.hpp
diff --git a/src/backends/reference/workloads/LstmUtils.cpp b/src/backends/reference/workloads/LstmUtils.cpp
new file mode 100644
index 0000000..f197aae
--- /dev/null
+++ b/src/backends/reference/workloads/LstmUtils.cpp
@@ -0,0 +1,307 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+//#pragma once
+
+#include "LstmUtils.hpp"
+#include "BaseIterator.hpp"
+#include <backendsCommon/CpuTensorHandle.hpp>
+
+
+// Helper functions ported from the Android code base
+// Refer to: android/external/tensorflow/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+
+void VectorBatchVectorAdd(armnn::Decoder<float>& vector,
+                          uint32_t vSize,
+                          armnn::Decoder<float>& batchVector,
+                          uint32_t nBatch,
+                          armnn::Encoder<float>& outResult)
+{
+    for (uint32_t b = 0; b < nBatch; b++)
+    {
+        for (uint32_t v = 0; v < vSize; v++)
+        {
+            outResult.Set(batchVector.Get() + vector.Get());
+            ++outResult;
+            ++vector;
+            ++batchVector;
+        }
+        vector -= vSize;
+    }
+    batchVector -= vSize * nBatch;
+    outResult -= vSize * nBatch;
+}
+
+
+// Layer norm for each batch.
+// normalization_epsilon is used in place of a zero variance to avoid division by zero.
+void MeanStddevNormalization(armnn::Decoder<float>& input_vector,
+                             armnn::Encoder<float>& output_vector,
+                             uint32_t v_size,
+                             uint32_t n_batch,
+                             float normalization_epsilon)
+{
+    for (uint32_t batch = 0; batch < n_batch; ++batch) {
+        float sum = 0.0f;
+        float sum_sq = 0.0f;
+        for (uint32_t i = 0; i < v_size; ++i) {
+            sum += input_vector.Get();
+            sum_sq += input_vector.Get() * input_vector.Get();
+            ++input_vector;
+        }
+        input_vector -= v_size;
+
+        const float mean = sum / static_cast<float>(v_size);
+        float stddev_inv = 0.0f;
+        const float variance = sum_sq / static_cast<float>(v_size) - mean * mean;
+        if (variance == 0) {
+            stddev_inv = 1.0f / std::sqrt(normalization_epsilon);
+        } else {
+            stddev_inv = 1.0f / std::sqrt(variance);
+        }
+
+        for (uint32_t i = 0; i < v_size; ++i) {
+            output_vector.Set((input_vector.Get() - mean) * stddev_inv);
+            ++output_vector;
+            ++input_vector;
+        }
+        // Don't reset iterator to handle next batch
+    }
+    output_vector -= v_size * n_batch;
+    input_vector -= v_size * n_batch;
+}
+
+void ZeroVector(armnn::Encoder<float>& vector,
+                uint32_t vSize)
+{
+    for (uint32_t v = 0; v < vSize; v++)
+    {
+        vector.Set(0.0f);
+        ++vector;
+    }
+    vector -= vSize;
+}
+
+void MatrixBatchVectorMultiplyAccumulate(armnn::Decoder<float>& matrix,
+                                         uint32_t mRows,
+                                         uint32_t mCols,
+                                         armnn::Decoder<float>& vector,
+                                         uint32_t nBatch,
+                                         armnn::Encoder<float>& outResult)
+{
+    for (uint32_t b = 0; b < nBatch; b++)
+    {
+        for (uint32_t r = 0; r < mRows; r++)
+        {
+            vector += b * mCols;
+            for (uint32_t c = 0; c < mCols; c++)
+            {
+                outResult.Set(outResult.Get() + matrix.Get() * vector.Get());
+                ++matrix;
+                ++vector;
+            }
+            outResult += 1;
+            vector -= (b+1) * mCols;
+        }
+        matrix -= (mRows * mCols);
+    }
+    outResult -= (mRows * nBatch);
+}
+
+void VectorBatchVectorAssign(armnn::Decoder<float>& vector,
+                             uint32_t vSize,
+                             uint32_t nBatch,
+                             armnn::Encoder<float>& outBatchVector)
+{
+    for (uint32_t b = 0; b < nBatch; b++)
+    {
+        for (uint32_t v = 0; v < vSize; v++)
+        {
+            outBatchVector.Set(vector.Get());
+            ++outBatchVector;
+            ++vector;
+        }
+        vector -= vSize;
+    }
+    outBatchVector -= (nBatch * vSize);
+}
+
+void VectorBatchVectorCwiseProductAccumulate(armnn::Decoder<float>& vector,
+                                             uint32_t vSize,
+                                             armnn::Decoder<float>& batchVector,
+                                             uint32_t nBatch,
+                                             armnn::Encoder<float>& outResult)
+{
+    for (uint32_t b = 0; b < nBatch; b++)
+    {
+        for (uint32_t v = 0; v < vSize; v++)
+        {
+            outResult.Set(outResult.Get() + vector.Get() * batchVector.Get());
+            ++outResult;
+            ++vector;
+            ++batchVector;
+        }
+        vector -= vSize;
+    }
+    batchVector -= vSize * nBatch;
+    outResult -= vSize * nBatch;
+}
+
+void VectorBatchVectorCwiseProduct(armnn::Decoder<float>& vector,
+                                   uint32_t vSize,
+                                   armnn::Decoder<float>& batchVector,
+                                   uint32_t nBatch,
+                                   armnn::Encoder<float>& outResult)
+{
+    for (uint32_t b = 0; b < nBatch; b++)
+    {
+        for (uint32_t v = 0; v < vSize; v++)
+        {
+            outResult.Set(vector.Get() * batchVector.Get());
+            ++outResult;
+            ++vector;
+            ++batchVector;
+        }
+        vector -= vSize;
+    }
+    batchVector -= vSize * nBatch;
+    outResult -= vSize * nBatch;
+}
+
+void Sub1Vector(armnn::Decoder<float>& vector,
+                uint32_t vSize,
+                armnn::Encoder<float>& result)
+{
+    for (uint32_t v = 0; v < vSize; v++)
+    {
+        result.Set(1.0f - vector.Get());
+        ++vector;
+        ++result;
+    }
+    vector -= vSize;
+    result -= vSize;
+}
+
+void VectorVectorCwiseProduct(armnn::Decoder<float>& vector1,
+                              armnn::Decoder<float>& vector2,
+                              uint32_t vSize,
+                              armnn::Encoder<float>& outResult)
+{
+    for (uint32_t v = 0; v < vSize; v++)
+    {
+        outResult.Set(vector1.Get() * vector2.Get());
+        ++outResult;
+        ++vector1;
+        ++vector2;
+    }
+    outResult -= vSize;
+    vector1 -= vSize;
+    vector2 -= vSize;
+}
+
+void VectorVectorCwiseProductAccumulate(armnn::Decoder<float>& vector1,
+                                        armnn::Decoder<float>& vector2,
+                                        uint32_t vSize,
+                                        armnn::Encoder<float>& outResult)
+{
+    for (uint32_t v = 0; v < vSize; v++)
+    {
+        outResult.Set(outResult.Get() + vector1.Get() * vector2.Get());
+        ++outResult;
+        ++vector1;
+        ++vector2;
+    }
+    outResult -= vSize;
+    vector1 -= vSize;
+    vector2 -= vSize;
+}
+
+float Clip(float f,
+           float absLimit)
+{
+    float result = (absLimit < f) ? absLimit : f;
+    result = (-absLimit > result) ? -absLimit : result;
+    return result;
+}
+
+void ClipVector(armnn::Decoder<float>& vector,
+                uint32_t vSize,
+                float absLimit,
+                armnn::Encoder<float>& outResult)
+{
+    for (uint32_t v = 0; v < vSize; v++)
+    {
+        outResult.Set(Clip(vector.Get(), absLimit));
+        ++vector;
+        ++outResult;
+    }
+    vector -= vSize;
+    outResult -= vSize;
+}
+
+void CopyVector(armnn::Decoder<float>& vector,
+                uint32_t vSize,
+                armnn::Encoder<float>& outResult)
+{
+    for (uint32_t v = 0; v < vSize; v++)
+    {
+        outResult.Set(vector.Get());
+        ++outResult;
+        ++vector;
+    }
+    outResult -= vSize;
+    vector -= vSize;
+}
+
+void SetActivationParameters(uint32_t activation,
+                             armnn::ActivationFunction& outArmnnActivation,
+                             float& outA,
+                             float& outB)
+{
+    switch (activation)
+    {
+        case 0: // None
+            outA = 0;
+            outB = 0;
+            return;
+
+        case 1: // Relu
+            outArmnnActivation = armnn::ActivationFunction::ReLu;
+            outA = 0;
+            outB = 0;
+            return;
+
+        case 3: // Relu6
+            outArmnnActivation = armnn::ActivationFunction::BoundedReLu;
+            outA = 6;
+            outB = 0;
+            return;
+
+        case 4: // Tanh
+            outArmnnActivation = armnn::ActivationFunction::TanH;
+            outA = 1;
+            outB = 1;
+            return;
+
+        case 6: // Sigmoid
+            outArmnnActivation = armnn::ActivationFunction::Sigmoid;
+            outA = 0;
+            outB = 0;
+            return;
+
+        default:
+            throw armnn::Exception("Unsupported activation function: " + std::to_string(activation));
+    }
+}
+
+std::unique_ptr<armnn::ScopedCpuTensorHandle> AssignScopedCpuTensorHandle(const armnn::ConstCpuTensorHandle* ptr)
+{
+    if (!ptr)
+    {
+        return nullptr;
+    }
+
+    return std::make_unique<armnn::ScopedCpuTensorHandle>(*ptr);
+}
diff --git a/src/backends/reference/workloads/LstmUtils.hpp b/src/backends/reference/workloads/LstmUtils.hpp
index db02a84..f6aff8b 100644
--- a/src/backends/reference/workloads/LstmUtils.hpp
+++ b/src/backends/reference/workloads/LstmUtils.hpp
@@ -8,211 +8,81 @@
 #include "BaseIterator.hpp"
 #include <backendsCommon/CpuTensorHandle.hpp>
 
-namespace
-{
-
 // Helper functions ported from the Android code base
 // Refer to: android/external/tensorflow/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
 
+
+void VectorBatchVectorAdd(armnn::Decoder<float>& vector,
+                          uint32_t vSize,
+                          armnn::Decoder<float>& batchVector,
+                          uint32_t nBatch,
+                          armnn::Encoder<float>& outResult);
+
+// Layer norm for each batch.
+// normalization_epsilon is used in place of a zero variance to avoid division by zero.
+void MeanStddevNormalization(armnn::Decoder<float>& input_vector,
+                             armnn::Encoder<float>& output_vector,
+                             uint32_t v_size,
+                             uint32_t n_batch,
+                             float normalization_epsilon);
+
+void ZeroVector(armnn::Encoder<float>& vector,
+                uint32_t vSize);
+
 void MatrixBatchVectorMultiplyAccumulate(armnn::Decoder<float>& matrix,
                                          uint32_t mRows,
                                          uint32_t mCols,
                                          armnn::Decoder<float>& vector,
                                          uint32_t nBatch,
-                                         armnn::Encoder<float>& outResult)
-{
-    for (uint32_t b = 0; b < nBatch; b++)
-    {
-        for (uint32_t r = 0; r < mRows; r++)
-        {
-            vector += b * mCols;
-            for (uint32_t c = 0; c < mCols; c++)
-            {
-                outResult.Set(outResult.Get() + matrix.Get() * vector.Get());
-                ++matrix;
-                ++vector;
-            }
-            outResult += 1;
-            vector -= (b+1) * mCols;
-        }
-        matrix -= (mRows * mCols);
-    }
-    outResult -= (mRows * nBatch);
-}
+                                         armnn::Encoder<float>& outResult);
 
 void VectorBatchVectorAssign(armnn::Decoder<float>& vector,
                              uint32_t vSize,
                              uint32_t nBatch,
-                             armnn::Encoder<float>& outBatchVector)
-{
-    for (uint32_t b = 0; b < nBatch; b++)
-    {
-        for (uint32_t v = 0; v < vSize; v++)
-        {
-            outBatchVector.Set(vector.Get());
-            ++outBatchVector;
-            ++vector;
-        }
-        vector -= vSize;
-    }
-    outBatchVector -= (nBatch * vSize);
-}
+                             armnn::Encoder<float>& outBatchVector);
 
 void VectorBatchVectorCwiseProductAccumulate(armnn::Decoder<float>& vector,
                                              uint32_t vSize,
                                              armnn::Decoder<float>& batchVector,
                                              uint32_t nBatch,
-                                             armnn::Encoder<float>& outResult)
-{
-    for (uint32_t b = 0; b < nBatch; b++)
-    {
-        for (uint32_t v = 0; v < vSize; v++)
-        {
-            outResult.Set(outResult.Get() + vector.Get() * batchVector.Get());
-            ++outResult;
-            ++vector;
-            ++batchVector;
-        }
-        vector -= vSize;
-    }
-    batchVector -= vSize * nBatch;
-    outResult -= vSize * nBatch;
-}
+                                             armnn::Encoder<float>& outResult);
+
+void VectorBatchVectorCwiseProduct(armnn::Decoder<float>& vector,
+                                   uint32_t vSize,
+                                   armnn::Decoder<float>& batchVector,
+                                   uint32_t nBatch,
+                                   armnn::Encoder<float>& outResult);
 
 void Sub1Vector(armnn::Decoder<float>& vector,
                 uint32_t vSize,
-                armnn::Encoder<float>& result)
-{
-    for (uint32_t v = 0; v < vSize; v++)
-    {
-        result.Set(1.0f - vector.Get());
-        ++vector;
-        ++result;
-    }
-    vector -= vSize;
-    result -= vSize;
-}
+                armnn::Encoder<float>& result);
+
 
 void VectorVectorCwiseProduct(armnn::Decoder<float>& vector1,
                               armnn::Decoder<float>& vector2,
                               uint32_t vSize,
-                              armnn::Encoder<float>& outResult)
-{
-    for (uint32_t v = 0; v < vSize; v++)
-    {
-        outResult.Set(vector1.Get() * vector2.Get());
-        ++outResult;
-        ++vector1;
-        ++vector2;
-    }
-    outResult -= vSize;
-    vector1 -= vSize;
-    vector2 -= vSize;
-}
+                              armnn::Encoder<float>& outResult);
 
 void VectorVectorCwiseProductAccumulate(armnn::Decoder<float>& vector1,
                                         armnn::Decoder<float>& vector2,
                                         uint32_t vSize,
-                                        armnn::Encoder<float>& outResult)
-{
-    for (uint32_t v = 0; v < vSize; v++)
-    {
-        outResult.Set(outResult.Get() + vector1.Get() * vector2.Get());
-        ++outResult;
-        ++vector1;
-        ++vector2;
-    }
-    outResult -= vSize;
-    vector1 -= vSize;
-    vector2 -= vSize;
-}
+                                        armnn::Encoder<float>& outResult);
 
 float Clip(float f,
-           float absLimit)
-{
-    float result = (absLimit < f) ? absLimit : f;
-    result = (-absLimit > result) ? -absLimit : result;
-    return result;
-}
+           float absLimit);
 
 void ClipVector(armnn::Decoder<float>& vector,
                 uint32_t vSize,
                 float absLimit,
-                armnn::Encoder<float>& outResult)
-{
-    for (uint32_t v = 0; v < vSize; v++)
-    {
-        outResult.Set(Clip(vector.Get(), absLimit));
-        ++vector;
-        ++outResult;
-    }
-    vector -= vSize;
-    outResult -= vSize;
-}
+                armnn::Encoder<float>& outResult);
 
 void CopyVector(armnn::Decoder<float>& vector,
                 uint32_t vSize,
-                armnn::Encoder<float>& outResult)
-{
-    for (uint32_t v = 0; v < vSize; v++)
-    {
-        outResult.Set(vector.Get());
-        ++outResult;
-        ++vector;
-    }
-    outResult -= vSize;
-    vector -= vSize;
-}
+                armnn::Encoder<float>& outResult);
 
 void SetActivationParameters(uint32_t activation,
                              armnn::ActivationFunction& outArmnnActivation,
                              float& outA,
-                             float& outB)
-{
-    switch (activation)
-    {
-    case 0: // None
-        outA = 0;
-        outB = 0;
-        return;
+                             float& outB);
 
-    case 1: // Relu
-        outArmnnActivation = armnn::ActivationFunction::ReLu;
-        outA = 0;
-        outB = 0;
-        return;
-
-    case 3: // Relu6
-        outArmnnActivation = armnn::ActivationFunction::BoundedReLu;
-        outA = 6;
-        outB = 0;
-        return;
-
-    case 4: // Tanh
-        outArmnnActivation = armnn::ActivationFunction::TanH;
-        outA = 1;
-        outB = 1;
-        return;
-
-    case 6: // Sigmoid
-        outArmnnActivation = armnn::ActivationFunction::Sigmoid;
-        outA = 0;
-        outB = 0;
-        return;
-
-    default:
-        throw armnn::Exception("Unsupported activation function: " + std::to_string(activation));
-    }
-}
-
-std::unique_ptr<armnn::ScopedCpuTensorHandle> AssignScopedCpuTensorHandle(const armnn::ConstCpuTensorHandle* ptr)
-{
-    if (!ptr)
-    {
-        return nullptr;
-    }
-
-    return std::make_unique<armnn::ScopedCpuTensorHandle>(*ptr);
-}
-
-} // anonymous namespace
+std::unique_ptr<armnn::ScopedCpuTensorHandle> AssignScopedCpuTensorHandle(const armnn::ConstCpuTensorHandle* ptr);
diff --git a/src/backends/reference/workloads/RefLstmWorkload.cpp b/src/backends/reference/workloads/RefLstmWorkload.cpp
index f8ebc58..70b3443 100644
--- a/src/backends/reference/workloads/RefLstmWorkload.cpp
+++ b/src/backends/reference/workloads/RefLstmWorkload.cpp
@@ -32,6 +32,10 @@
     , m_OutputGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))
     , m_ProjectionWeightsTensor       (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
     , m_ProjectionBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))
+    , m_InputLayerNormWeights         (AssignScopedCpuTensorHandle(descriptor.m_InputLayerNormWeights))
+    , m_ForgetLayerNormWeights        (AssignScopedCpuTensorHandle(descriptor.m_ForgetLayerNormWeights))
+    , m_CellLayerNormWeights          (AssignScopedCpuTensorHandle(descriptor.m_CellLayerNormWeights))
+    , m_OutputLayerNormWeights        (AssignScopedCpuTensorHandle(descriptor.m_OutputLayerNormWeights))
 {}
 
 void RefLstmWorkload::Execute() const
@@ -62,8 +66,9 @@
     const uint32_t nCell   = m_InputToOutputWeightsTensor->GetShape()[0];
     const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1];
 
-    const bool useCifg     = m_Data.m_Parameters.m_CifgEnabled;
-    const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
+    const bool useCifg      = m_Data.m_Parameters.m_CifgEnabled;
+    const bool usePeephole  = m_Data.m_Parameters.m_PeepholeEnabled;
+    const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;
 
     // Index the scratch buffers pointers to the global scratch buffer.
     std::unique_ptr<Encoder<float>> inputGateScratch  = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
@@ -134,6 +139,26 @@
     std::unique_ptr<Decoder<float>> projectionWeightsTensor;
     std::unique_ptr<Decoder<float>> projectionBiasTensor;
 
+    std::unique_ptr<Decoder<float>> inputLayerNormWeights;
+    std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
+    std::unique_ptr<Decoder<float>> cellLayerNormWeights;
+    std::unique_ptr<Decoder<float>> outputLayerNormWeights;
+
+    if (useLayerNorm)
+    {
+        if (!useCifg)
+        {
+            inputLayerNormWeights = MakeDecoder<float>(
+                    m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetTensor<void>());
+        }
+        forgetLayerNormWeights = MakeDecoder<float>(
+                m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetTensor<void>());
+        cellLayerNormWeights = MakeDecoder<float>(
+                m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetTensor<void>());
+        outputLayerNormWeights = MakeDecoder<float>(
+                m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetTensor<void>());
+    }
+
     if (!useCifg)
     {
         inputToInputWeightsTensor = MakeDecoder<float>(
@@ -169,18 +194,32 @@
         }
     }
 
-    // Initialize scratch buffers with bias.
-    if (!useCifg)
+    if (!useLayerNorm)
     {
-        VectorBatchVectorAssign(*inputGateBiasTensor,
-                                nCell, nBatch, *inputGateScratch);
+        // Initialize scratch buffers with bias.
+        if (!useCifg)
+        {
+            VectorBatchVectorAssign(*inputGateBiasTensor,
+                                    nCell, nBatch, *inputGateScratch);
+        }
+        VectorBatchVectorAssign(*forgetGateBiasTensor,
+                                nCell, nBatch, *forgetGateScratch);
+        VectorBatchVectorAssign(*cellBiasTensor,
+                                nCell, nBatch, *cellScratch);
+        VectorBatchVectorAssign(*outputGateBiasTensor,
+                                nCell, nBatch, *outputGateScratch);
     }
-    VectorBatchVectorAssign(*forgetGateBiasTensor,
-                            nCell, nBatch, *forgetGateScratch);
-    VectorBatchVectorAssign(*cellBiasTensor,
-                            nCell, nBatch, *cellScratch);
-    VectorBatchVectorAssign(*outputGateBiasTensor,
-                            nCell, nBatch, *outputGateScratch);
+    else
+    {
+        // Initialize scratch buffers with zeroes.
+        if (!useCifg)
+        {
+            ZeroVector(*inputGateScratch, nCell * nBatch);
+        }
+        ZeroVector(*forgetGateScratch, nCell * nBatch);
+        ZeroVector(*cellScratch,       nCell * nBatch);
+        ZeroVector(*outputGateScratch, nCell * nBatch);
+    }
 
     // For each batch and cell: compute input_weight * input.
     if (!useCifg)
@@ -216,6 +255,15 @@
             VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor,
                                                     nCell, *cellStateIn, nBatch, *inputGateScratch);
         }
+        if (useLayerNorm)
+        {
+            MeanStddevNormalization(*inputGateScratchDecoder,
+                                    *inputGateScratch, nCell, nBatch, m_LayerNormEpsilon);
+            VectorBatchVectorCwiseProduct(*inputLayerNormWeights,
+                                          nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
+            VectorBatchVectorAdd(*inputGateBiasTensor,
+                                 nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
+        }
         Activation(*inputGateScratchDecoder, *inputGateScratch,
                    TensorInfo({nCell, nBatch}, outputType),
                    ActivationFunction::Sigmoid, 0, 0);
@@ -227,11 +275,30 @@
         VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell,
                                                 *cellStateIn, nBatch, *forgetGateScratch);
     }
+    if (useLayerNorm)
+    {
+        MeanStddevNormalization(*forgetGateScratchDecoder,
+                                *forgetGateScratch, nCell, nBatch, m_LayerNormEpsilon);
+        VectorBatchVectorCwiseProduct(*forgetLayerNormWeights,
+                                      nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
+        VectorBatchVectorAdd(*forgetGateBiasTensor,
+                             nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
+    }
     Activation(*forgetGateScratchDecoder, *forgetGateScratch,
                TensorInfo({nCell, nBatch}, outputType),
                ActivationFunction::Sigmoid, 0, 0);
 
     // For each batch and cell: update the cell.
+    if (useLayerNorm)
+    {
+        MeanStddevNormalization(*cellScratchDecoder,
+                                *cellScratch, nCell, nBatch, m_LayerNormEpsilon);
+        VectorBatchVectorCwiseProduct(*cellLayerNormWeights,
+                                      nCell, *cellScratchDecoder, nBatch, *cellScratch);
+        VectorBatchVectorAdd(*cellBiasTensor,
+                             nCell, *cellScratchDecoder, nBatch, *cellScratch);
+    }
+
     VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut);
 
     ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
@@ -267,6 +334,15 @@
         VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor,
                                                 nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
     }
+    if (useLayerNorm)
+    {
+        MeanStddevNormalization(*outputGateScratchDecoder,
+                                *outputGateScratch, nCell, nBatch, m_LayerNormEpsilon);
+        VectorBatchVectorCwiseProduct(*outputLayerNormWeights,
+                                      nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
+        VectorBatchVectorAdd(*outputGateBiasTensor,
+                             nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
+    }
     Activation(*outputGateScratchDecoder, *outputGateScratch,
                TensorInfo({nCell, nBatch}, outputType),
                ActivationFunction::Sigmoid, 0, 0);
diff --git a/src/backends/reference/workloads/RefLstmWorkload.hpp b/src/backends/reference/workloads/RefLstmWorkload.hpp
index 38e3fb9..ce5a775 100644
--- a/src/backends/reference/workloads/RefLstmWorkload.hpp
+++ b/src/backends/reference/workloads/RefLstmWorkload.hpp
@@ -38,6 +38,12 @@
     std::unique_ptr<ScopedCpuTensorHandle> m_OutputGateBiasTensor;
     std::unique_ptr<ScopedCpuTensorHandle> m_ProjectionWeightsTensor;
     std::unique_ptr<ScopedCpuTensorHandle> m_ProjectionBiasTensor;
+    std::unique_ptr<ScopedCpuTensorHandle> m_InputLayerNormWeights;
+    std::unique_ptr<ScopedCpuTensorHandle> m_ForgetLayerNormWeights;
+    std::unique_ptr<ScopedCpuTensorHandle> m_CellLayerNormWeights;
+    std::unique_ptr<ScopedCpuTensorHandle> m_OutputLayerNormWeights;
+
+    float m_LayerNormEpsilon = static_cast<float>(1e-8);
 };
 
 } //namespace armnn