Add batch dimension to MatMul.
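
The batched semantics are: a [N, H, C] multiplied with b [N, C, W] yields output [N, H, W], i.e. one [H, C] x [C, W] matrix multiply per batch slice. A minimal standalone sketch of those semantics follows (plain C++ over flat row-major buffers; the function name and the int32/int64 element types are illustrative and do not come from the reference model):

    #include <cstdint>
    #include <vector>

    // Illustrative only: batched MatMul on flat row-major buffers.
    // a is [N, H, C], b is [N, C, W], the result is [N, H, W].
    std::vector<int64_t> batched_matmul(const std::vector<int32_t>& a,
                                        const std::vector<int32_t>& b,
                                        int N, int H, int C, int W)
    {
        std::vector<int64_t> out(static_cast<size_t>(N) * H * W, 0);
        for (int n = 0; n < N; n++)                // batch
            for (int h = 0; h < H; h++)            // rows of a
                for (int w = 0; w < W; w++)        // columns of b
                    for (int c = 0; c < C; c++)    // contracted dimension
                        out[(n * H + h) * W + w] +=
                            static_cast<int64_t>(a[(n * H + h) * C + c]) * b[(n * C + c) * W + w];
        return out;
    }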

Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I83f75dd5beb60fe7ca2d573ea0f81bac4cd62a07
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index b8c7ade..0007553 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -742,7 +742,7 @@
     : GraphNode(Op_MATMUL, id_)
 {
     setRequiredOperands(2, 1);
-    setRequiredRank(2);
+    setRequiredRank(3);
 
     INIT_QINFO(MatMul);
 }
@@ -765,16 +765,47 @@
         return 1;
     }
 
-    a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
-    b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+    a      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    b      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+    output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
 
-    if (a->getShape()[1] != b->getShape()[0])
+    ASSERT_MEM(a && b && output);
+
+    // a: [N, H, C]
+    // b: [N, C, W]
+    // output: [N, H, W]
+
+    // Check N
+    if (a->getShape()[0] != b->getShape()[0] || a->getShape()[0] != output->getShape()[0])
     {
-        printNodeValidationError("OpMatMul operator a.shape[1] should match b.shape[0]");
+        printNodeValidationError("OpMatMul operator a.shape[0], b.shape[0] and output.shape[0] should match");
         return 1;
     }
+    N = a->getShape()[0];
 
-    c = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+    // Check C
+    if (a->getShape()[2] != b->getShape()[1])
+    {
+        printNodeValidationError("OpMatMul operator a.shape[2] should match b.shape[1]");
+        return 1;
+    }
+    C = a->getShape()[2];
+
+    // Check H
+    if (a->getShape()[1] != output->getShape()[1])
+    {
+        printNodeValidationError("OpMatMul operator a.shape[1] should match output.shape[1]");
+        return 1;
+    }
+    H = a->getShape()[1];
+
+    // Check W
+    if (b->getShape()[2] != output->getShape()[2])
+    {
+        printNodeValidationError("OpMatMul operator b.shape[2] should match output.shape[2]");
+        return 1;
+    }
+    W = b->getShape()[2];
 
     return 0;
 }
@@ -793,12 +824,42 @@
         b_val = b_val - (InEigenType)this->qinfo->b_zp();
     }
 
-    this->c->getTensor() = a_val.template cast<AccEigenType>().contract(b_val.template cast<AccEigenType>(), dims);
+    Eigen::array<Eigen::Index, 2> a_rank2_shape({ H, C });
+    Eigen::array<Eigen::Index, 2> b_rank2_shape({ C, W });
+    Eigen::array<Eigen::Index, 3> output_rank3_shape({ 1, H, W });
+
+    Eigen::array<Eigen::Index, 3> a_size_array({ 1, H, C });
+    Eigen::array<Eigen::Index, 3> b_size_array({ 1, C, W });
+
+    Eigen::array<Eigen::Index, 3> a_begin_array({ 0, 0, 0 });
+    Eigen::array<Eigen::Index, 3> b_begin_array({ 0, 0, 0 });
+
+    // Iterate over the N (batch) dimension, computing one [H, W] matrix multiply per slice.
+    for (int i = 0; i < N; i++)
+    {
+        a_begin_array[0] = i;
+        b_begin_array[0] = i;
+
+        TInRank2 a_rank2_val = a_val.slice(a_begin_array, a_size_array).reshape(a_rank2_shape);
+        TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape);
+        TAccRank2 output_rank2_val =
+            a_rank2_val.template cast<AccEigenType>().contract(b_rank2_val.template cast<AccEigenType>(), dims);
+        TAcc output_rank3_val = output_rank2_val.reshape(output_rank3_shape);
+        if (i == 0)
+        {
+            this->output->getTensor() = output_rank3_val;
+        }
+        else
+        {
+            TAcc temp                 = this->output->getTensor().concatenate(output_rank3_val, 0);
+            this->output->getTensor() = temp;
+        }
+    }
 
     if (AccDtype == DType_INT48)
     {
-        this->c->getTensor() = this->c->getTensor().cwiseMax((AccEigenType)AccQMin);
-        this->c->getTensor() = this->c->getTensor().cwiseMin((AccEigenType)AccQMax);
+        this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+        this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
     }
 
     return GraphNode::eval();