COMPMID-1781 Add channel support in CLL2Normalization

Change-Id: Ibab049f09413258c99335b7da6b151530a1bd136
diff --git a/tests/validation/reference/L2NormalizeLayer.cpp b/tests/validation/reference/L2NormalizeLayer.cpp
index 2667751..fcd6226 100644
--- a/tests/validation/reference/L2NormalizeLayer.cpp
+++ b/tests/validation/reference/L2NormalizeLayer.cpp
@@ -57,24 +57,38 @@
     SimpleTensor<T> sum = reduction_operation(src, get_output_shape(src.shape(), axis), axis, ReductionOperation::SUM_SQUARE);
 
     // Compute reference
-    const int elems      = src.shape()[axis];
-    const int upper_dims = src.shape().total_size_upper(axis + 1);
+    const int upper_dims     = src.shape().total_size_upper(axis + 1);
+    const int lower_dims     = src.shape().total_size_lower(axis + 1);
+    const int lower_dims_sum = sum.shape().total_size_lower(axis + 1);
 
     for(int du = 0; du < upper_dims; ++du)
     {
-        if(axis == 0)
+        const T *src_row_ptr = src.data() + du * lower_dims;
+        T       *dst_row_ptr = dst.data() + du * lower_dims;
+        switch(axis)
         {
-            const T *src_row_ptr         = src.data() + du * elems;
-            T       *dst_row_ptr         = dst.data() + du * elems;
-            const T  normalization_value = sqrt(std::max(sum[du], static_cast<T>(epsilon)));
-            std::transform(src_row_ptr, src_row_ptr + elems, dst_row_ptr, [normalization_value](T val)
+            case 0:
             {
-                return val / normalization_value;
-            });
-        }
-        else
-        {
-            ARM_COMPUTE_ERROR("Unsupported normalization axis");
+                const int elems               = src.shape()[0];
+                const T   normalization_value = sqrt(std::max(sum[du], static_cast<T>(epsilon)));
+                std::transform(src_row_ptr, src_row_ptr + elems, dst_row_ptr, [normalization_value](T val)
+                {
+                    return val / normalization_value;
+                });
+            }
+            break;
+            case 1:
+            case 2:
+            {
+                for(int ld = 0; ld < lower_dims; ++ld)
+                {
+                    const T normalization_value = sqrt(std::max(sum[ld % lower_dims_sum + du * lower_dims_sum], static_cast<T>(epsilon)));
+                    dst_row_ptr[ld]             = src_row_ptr[ld] / normalization_value;
+                }
+            }
+            break;
+            default:
+                ARM_COMPUTE_ERROR("Axis not supported");
         }
     }