COMPMID-1518: Add support for GEMM3D in CLGEMMLowpMatrixMultiplyCore

Change-Id: Ib14ac821ee5d4aff80bd602cd3e76e7018abb5e6
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/150268
Tested-by: bsgcomp <bsgcomp@arm.com>
Reviewed-by: Isabella Gottardi <isabella.gottardi@arm.com>
Reviewed-by: Michele DiGiorgio <michele.digiorgio@arm.com>
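
The patch below extends the GEMMLowp reference so it iterates over a batch (3rd) dimension: matrix A and the output C get an explicit z-stride, while matrix B is only slid along z when it is itself 3-dimensional, and the output shape is now passed in as shape_c. As a minimal, stand-alone sketch of that batched quantized multiply (raw vectors and the b_is_3d flag are illustrative assumptions, not the library's API):

    // Stand-alone sketch of the batched (GEMM3D) low-precision multiply.
    // Layouts: a is D x M x K, b is (D or 1) x K x N, c is D x M x N, all row-major.
    #include <cstdint>
    #include <vector>

    std::vector<int32_t> gemmlowp_3d(const std::vector<uint8_t> &a,
                                     const std::vector<uint8_t> &b,
                                     int M, int N, int K, int D, bool b_is_3d,
                                     int32_t a_offset, int32_t b_offset)
    {
        std::vector<int32_t> c(static_cast<size_t>(D) * M * N, 0);

        const int a_stride_z = M * K;
        // B is only slid along the batch dimension when it is itself 3D.
        const int b_stride_z = b_is_3d ? K * N : 0;
        const int c_stride_z = M * N;

        for(int depth = 0; depth < D; ++depth)
        {
            for(int i = 0; i < M; ++i)
            {
                for(int j = 0; j < N; ++j)
                {
                    int32_t acc = 0;
                    for(int k = 0; k < K; ++k)
                    {
                        // Offsets are added to the raw quantized values before accumulation.
                        const int32_t tmp_a = a_offset + a[depth * a_stride_z + i * K + k];
                        const int32_t tmp_b = b_offset + b[depth * b_stride_z + k * N + j];
                        acc += tmp_a * tmp_b;
                    }
                    c[depth * c_stride_z + i * N + j] = acc;
                }
            }
        }
        return c;
    }
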
diff --git a/tests/validation/reference/GEMMLowp.cpp b/tests/validation/reference/GEMMLowp.cpp
index 8e41aef..9a7e409 100644
--- a/tests/validation/reference/GEMMLowp.cpp
+++ b/tests/validation/reference/GEMMLowp.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -98,41 +98,52 @@
 } // namespace
 
 template <typename T_out, typename T_in>
-SimpleTensor<T_out> gemmlowp_matrix_multiply_core(const SimpleTensor<T_in> &a, const SimpleTensor<T_in> &b, int32_t a_offset, int32_t b_offset)
+SimpleTensor<T_out> gemmlowp_matrix_multiply_core(const SimpleTensor<T_in> &a, const SimpleTensor<T_in> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset)
 {
     static_assert(std::is_same<typename std::decay<T_out>::type, int32_t>::value, "Only int32_t is allowed for the output");
 
-    TensorShape         shape(b.shape()[0], a.shape()[1]);
     DataType            dt = std::is_same<T_out, int32_t>::value ? DataType::S32 : DataType::U32;
-    SimpleTensor<T_out> c(shape, dt);
+    SimpleTensor<T_out> c(shape_c, dt);
 
-    const int K       = a.shape().x();
-    const int b_width = b.shape().x();
-    const int rows    = c.shape().y(); //M
-    const int cols    = c.shape().x(); //N
+    const int K = a.shape().x();
+    const int M = a.shape().y();
+    const int N = b.shape().x();
+    const int D = a.shape().z(); // Number of matrices in a batch
+
+    const int a_stride_z = K * M;
+    // Do not slide the matrix B along the 3rd dimension in case matrix B has less than 3 dimensions
+    const int b_stride_z = b.shape().num_dimensions() > 2 ? N * K : 0;
+    const int c_stride_z = N * M;
 
     std::vector<T_out> acc;
-    acc.resize(cols);
+    acc.resize(N);
 
-    for(int i = 0; i < rows; ++i)
+    for(int depth = 0; depth < D; ++depth)
     {
-        for(int j = 0; j < cols; ++j)
+        const int base_addr_a = depth * a_stride_z;
+        const int base_addr_b = depth * b_stride_z;
+        const int base_addr_c = depth * c_stride_z;
+
+        for(int i = 0; i < M; ++i)
         {
-            acc[j] = 0;
-        }
-        for(int k = 0; k < K; ++k)
-        {
-            const T_out tmp_a = a_offset + static_cast<T_out>(a[k + i * K]);
-            for(int j = 0; j < b_width; ++j)
+            for(int j = 0; j < N; ++j)
             {
-                const T_out tmp_b       = b_offset + static_cast<T_out>(b[j + k * b_width]);
-                const T_out mult_as_int = tmp_a * tmp_b;
-                acc[j] += mult_as_int;
+                acc[j] = 0;
             }
-        }
-        for(int j = 0; j < cols; ++j)
-        {
-            c[j + i * cols] = acc[j];
+            for(int k = 0; k < K; ++k)
+            {
+                const T_out tmp_a = a_offset + static_cast<T_out>(a[base_addr_a + k + i * K]);
+                for(int j = 0; j < N; ++j)
+                {
+                    const T_out tmp_b       = b_offset + static_cast<T_out>(b[base_addr_b + j + k * N]);
+                    const T_out mult_as_int = tmp_a * tmp_b;
+                    acc[j] += mult_as_int;
+                }
+            }
+            for(int j = 0; j < N; ++j)
+            {
+                c[base_addr_c + j + i * N] = acc[j];
+            }
         }
     }
 
@@ -141,9 +152,9 @@
 
 // used to validate assembly kernels which don't know anything about offsets
 template <typename T1, typename T2>
-SimpleTensor<T1> gemmlowp(const SimpleTensor<T2> &a, const SimpleTensor<T2> &b)
+SimpleTensor<T1> gemmlowp(const SimpleTensor<T2> &a, const SimpleTensor<T2> &b, TensorShape shape_c)
 {
-    return gemmlowp_matrix_multiply_core<T1, T2>(a, b, 0, 0);
+    return gemmlowp_matrix_multiply_core<T1, T2>(a, b, shape_c, 0, 0);
 }
 
 template <typename T>
@@ -198,10 +209,10 @@
                                                                            int32_t max);
 template SimpleTensor<uint8_t> gemmlowp_quantize_down_int32_to_uint8_scale(const SimpleTensor<int32_t> &a, const SimpleTensor<int32_t> &b, int32_t result_offset, int32_t result_mult_int,
                                                                            int32_t result_shift, int32_t min, int32_t max);
-template SimpleTensor<int32_t> gemmlowp_matrix_multiply_core(const SimpleTensor<int8_t> &a, const SimpleTensor<int8_t> &b, int32_t a_offset, int32_t b_offset);
-template SimpleTensor<int32_t> gemmlowp_matrix_multiply_core(const SimpleTensor<uint8_t> &a, const SimpleTensor<uint8_t> &b, int32_t a_offset, int32_t b_offset);
-template SimpleTensor<int32_t> gemmlowp(const SimpleTensor<int8_t> &a, const SimpleTensor<int8_t> &b);
-template SimpleTensor<int32_t> gemmlowp(const SimpleTensor<uint8_t> &a, const SimpleTensor<uint8_t> &b);
+template SimpleTensor<int32_t> gemmlowp_matrix_multiply_core(const SimpleTensor<int8_t> &a, const SimpleTensor<int8_t> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset);
+template SimpleTensor<int32_t> gemmlowp_matrix_multiply_core(const SimpleTensor<uint8_t> &a, const SimpleTensor<uint8_t> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset);
+template SimpleTensor<int32_t> gemmlowp(const SimpleTensor<int8_t> &a, const SimpleTensor<int8_t> &b, TensorShape shape_c);
+template SimpleTensor<int32_t> gemmlowp(const SimpleTensor<uint8_t> &a, const SimpleTensor<uint8_t> &b, TensorShape shape_c);
 } // namespace reference
 } // namespace validation
 } // namespace test
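
For context, a call site for the updated reference would now pass the output shape explicitly so it can carry the batch dimension. The sketch below is an assumption for illustration: the shapes, offsets, fill step and include paths are placeholders, not taken from an existing test.

    // Hypothetical call of the updated reference (shapes/offsets are made up).
    #include "arm_compute/core/TensorShape.h"
    #include "arm_compute/core/Types.h"
    #include "tests/SimpleTensor.h"
    #include "tests/validation/reference/GEMMLowp.h"

    using namespace arm_compute;
    using namespace arm_compute::test;
    using namespace arm_compute::test::validation;

    SimpleTensor<int32_t> run_reference()
    {
        // a: K=16, M=8, D=4 batches; b: N=12, K=16 (2D, so it is reused for every batch).
        SimpleTensor<uint8_t> a{ TensorShape(16U, 8U, 4U), DataType::QASYMM8 };
        SimpleTensor<uint8_t> b{ TensorShape(12U, 16U), DataType::QASYMM8 };
        const TensorShape     shape_c(12U, 8U, 4U);

        // ... fill a and b with test data ...

        // The output shape is now an explicit argument of the reference.
        return reference::gemmlowp_matrix_multiply_core<int32_t>(a, b, shape_c, /*a_offset*/ -3, /*b_offset*/ 5);
    }
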