adding batch dimension to MatMul.

Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I83f75dd5beb60fe7ca2d573ea0f81bac4cd62a07
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index b8c7ade..0007553 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -742,7 +742,7 @@
     : GraphNode(Op_MATMUL, id_)
 {
     setRequiredOperands(2, 1);
-    setRequiredRank(2);
+    setRequiredRank(3);
 
     INIT_QINFO(MatMul);
 }
@@ -765,16 +765,47 @@
         return 1;
     }
 
-    a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
-    b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+    a      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    b      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+    output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
 
-    if (a->getShape()[1] != b->getShape()[0])
+    ASSERT_MEM(a && b && output);
+
+    // a: [N, H, C]
+    // b: [N, C, W]
+    // c: [N, H, W]
+
+    // Check N
+    if (a->getShape()[0] != b->getShape()[0] || a->getShape()[0] != output->getShape()[0])
     {
-        printNodeValidationError("OpMatMul operator a.shape[1] should match b.shape[0]");
+        printNodeValidationError("OpMatMul operator a.shape[0], b.shape[0] and output.shape[0] should match");
         return 1;
     }
+    N = a->getShape()[0];
 
-    c = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+    // Check C
+    if (a->getShape()[2] != b->getShape()[1])
+    {
+        printNodeValidationError("OpMatMul operator a.shape[2] should match b.shape[1]");
+        return 1;
+    }
+    C = a->getShape()[2];
+
+    // Check H
+    if (a->getShape()[1] != output->getShape()[1])
+    {
+        printNodeValidationError("OpMatMul operator a.shape[1] should match output.shape[1]");
+        return 1;
+    }
+    H = a->getShape()[1];
+
+    // Check W
+    if (b->getShape()[2] != output->getShape()[2])
+    {
+        printNodeValidationError("OpMatMul operator b.shape[2] should match output.shape[2]");
+        return 1;
+    }
+    W = b->getShape()[2];
 
     return 0;
 }
@@ -793,12 +824,42 @@
         b_val = b_val - (InEigenType)this->qinfo->b_zp();
     }
 
-    this->c->getTensor() = a_val.template cast<AccEigenType>().contract(b_val.template cast<AccEigenType>(), dims);
+    Eigen::array<Eigen::Index, 2> a_rank2_shape({ H, C });
+    Eigen::array<Eigen::Index, 2> b_rank2_shape({ C, W });
+    Eigen::array<Eigen::Index, 3> output_rank3_shape({ 1, H, W });
+
+    Eigen::array<Eigen::Index, 3> a_size_array({ 1, H, C });
+    Eigen::array<Eigen::Index, 3> b_size_array({ 1, C, W });
+
+    Eigen::array<Eigen::Index, 3> a_begin_array({ 0, 0, 0 });
+    Eigen::array<Eigen::Index, 3> b_begin_array({ 0, 0, 0 });
+
+    // Iterate N dimension.
+    for (int i = 0; i < N; i++)
+    {
+        a_begin_array[0] = i;
+        b_begin_array[0] = i;
+
+        TInRank2 a_rank2_val = a_val.slice(a_begin_array, a_size_array).reshape(a_rank2_shape);
+        TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape);
+        TAccRank2 output_rank2_val =
+            a_rank2_val.template cast<AccEigenType>().contract(b_rank2_val.template cast<AccEigenType>(), dims);
+        TAcc output_rank3_val = output_rank2_val.reshape(output_rank3_shape);
+        if (i == 0)
+        {
+            this->output->getTensor() = output_rank3_val;
+        }
+        else
+        {
+            TAcc temp                 = this->output->getTensor().concatenate(output_rank3_val, 0);
+            this->output->getTensor() = temp;
+        }
+    }
 
     if (AccDtype == DType_INT48)
     {
-        this->c->getTensor() = this->c->getTensor().cwiseMax((AccEigenType)AccQMin);
-        this->c->getTensor() = this->c->getTensor().cwiseMin((AccEigenType)AccQMax);
+        this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+        this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
     }
 
     return GraphNode::eval();
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index 26ce84b..9aaa140 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -183,15 +183,21 @@
     static constexpr DType AccDtype  = GetAccDType<Dtype, Dtype>::value;
     using InEigenType                = typename GetEigenType<Dtype>::type;
     using AccEigenType               = typename GetEigenType<AccDtype>::type;
-    using TIn                        = Eigen::Tensor<InEigenType, 2>;
-    using TAcc                       = Eigen::Tensor<AccEigenType, 2>;
+    using TIn                        = Eigen::Tensor<InEigenType, 3>;
+    using TAcc                       = Eigen::Tensor<AccEigenType, 3>;
+    using TInRank2                   = Eigen::Tensor<InEigenType, 2>;
+    using TAccRank2                  = Eigen::Tensor<AccEigenType, 2>;
     static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
     static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
 
 protected:
     TosaReference::TensorTemplate<TIn>* a;
     TosaReference::TensorTemplate<TIn>* b;
-    TosaReference::TensorTemplate<TAcc>* c;
+    TosaReference::TensorTemplate<TAcc>* output;
+    int64_t N;
+    int64_t H;
+    int64_t W;
+    int64_t C;
     tosa::TosaMatMulQuantInfo* qinfo;
 };
 
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 5670d1b..6f9acf4 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -314,12 +314,12 @@
     def tgMatmul(testGen, op, rank):
         pl, const = op["operands"]
 
-        assert rank == 2
+        assert rank == 3
         assert pl == 2 and const == 0
 
         a_shape = testGen.makeShape(rank)
         b_oc = testGen.makeShape(1)[0]
-        b_shape = np.asarray([a_shape[1], b_oc])
+        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
 
         return [a_shape, b_shape]
 
@@ -1994,7 +1994,7 @@
         "matmul": {
             "op": Op.MATMUL,
             "operands": (2, 0),
-            "rank": (2, 2),
+            "rank": (3, 3),
             "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
             "qgen": TosaQuantGen.qgMatmul,
             "types": TYPE_NARROW_INT_FP,
@@ -2630,11 +2630,11 @@
 
     @staticmethod
     def matmulOp(ser, a, b):
-        # a: M, K
-        # b: K, N
-        # out: M, N
+        # a: N, H, C
+        # b: N, C, W
+        # out: N, H, W
 
-        output_shape = [a.shape[0], b.shape[1]]
+        output_shape = [a.shape[0], a.shape[1], b.shape[2]]
 
         if a.dtype == DType.INT8:
             out_dtype = DType.INT32