Remove rank 0 broadcast in binary op.

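This drops the implicit rank-0 ("scalar") broadcast path from the
elementwise binary ops. Both inputs must now arrive with matching rank
and type (matchRankType); for MUL only the output's rank is checked, so
its type may still be wider than the inputs'. Broadcasting of size-1
dimensions is unchanged; only the rank-0-to-rank-N promotion goes away.

Illustrative sketch (not part of this patch): a producer that used to
feed a rank-0 tensor into a binary op would now reshape it to a
rank-matched tensor of size-1 dimensions first. The hypothetical Eigen
helper below does roughly what the removed a_rank0/b_rank0 path did
internally:

    #include <unsupported/Eigen/CXX11/Tensor>

    // Promote a rank-0 tensor to rank 2 (all dimensions 1), then
    // broadcast it to the partner operand's shape -- the same
    // reshape(reshaper).broadcast(bcast) sequence this patch deletes.
    Eigen::Tensor<int32_t, 2> promoteScalar(const Eigen::Tensor<int32_t, 0>& s,
                                            const Eigen::array<Eigen::Index, 2>& out_shape)
    {
        Eigen::array<Eigen::Index, 2> ones = { 1, 1 };
        // All dims are 1 after the reshape, so the per-dimension
        // broadcast factors equal the target shape.
        return s.reshape(ones).broadcast(out_shape);
    }
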
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I14bec5020c91f7abd6c1adc31068a22961330a97
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index c33f646..a11d855 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -32,10 +32,8 @@
     setRequiredOperands(2, 1);
     setRequiredRank(0, 6);
 
-    a_rank = b_rank = max_input_rank = -1;
-    a = b   = nullptr;
-    a_rank0 = b_rank0 = nullptr;
-    result            = nullptr;
+    a = b  = nullptr;
+    result = nullptr;
 
     fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
 }
@@ -55,54 +53,37 @@
         return 1;
     }
 
-    a_rank = inputs[0]->getRank();
-    b_rank = inputs[1]->getRank();
-    if (a_rank != 0 && b_rank != 0 && a_rank != b_rank)
-    {
-        printNodeValidationError("Binary operator input ranks must match");
-        return 1;
-    }
-
-    max_input_rank = a_rank >= b_rank ? a_rank : b_rank;
-
-    // A & B must be the same types
-    if (inputs[0]->matchType(*inputs[1]))
+    // A & B must have the same rank and type
+    if (inputs[0]->matchRankType(*inputs[1]))
     {
         printNodeValidationError("Binary operator input types must match");
         return 1;
     }
 
-    // Result's geometry must match, but the type may be wider
-    if (outputs[0]->getRank() != max_input_rank)
+    // Input and output ranks must match.
+    // For all ops except MUL, the types must match as well.
+    if (nodeType != Op_MUL)
     {
-        printNodeValidationError("Binary operator input and output genometry must match");
-        return 1;
-    }
-
-    if (a_rank == max_input_rank)
-    {
-        a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+        if (inputs[0]->matchRankType(*outputs[0]))
+        {
+            printNodeValidationError("Binary operators (except MUL) require matching input and output rank and type");
+            return 1;
+        }
     }
     else
     {
-        a_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[0]);
+        if (inputs[0]->matchRank(*outputs[0]))
+        {
+            printNodeValidationError("MUL operator input and output rank must match");
+            return 1;
+        }
     }
 
-    if (b_rank == max_input_rank)
-    {
-        b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
-    }
-    else
-    {
-        b_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[1]);
-    }
-
+    a      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    b      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
     result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
 
-    // either a or b can be rank0
-    // a_rank0 and b_rank0 can't be valid at the same time.
-    // if a and be are both rank0, they should be evaulated as 'a' and 'b', instead of 'a_rank0' and 'b_rank0'
-    ASSERT_MEM((a || a_rank0) && (b || b_rank0) && !(a_rank0 && b_rank0) && result);
+    ASSERT_MEM(a && b && result);
 
     return 0;
 }
@@ -114,25 +95,10 @@
 
     std::vector<int> a_shape, b_shape;
 
-    if (a_rank == max_input_rank)
-    {
-        a_shape = a->getShape();
-    }
-    else
-    {
-        a_shape.assign(max_input_rank, 1);
-    }
+    a_shape = a->getShape();
+    b_shape = b->getShape();
 
-    if (b_rank == max_input_rank)
-    {
-        b_shape = b->getShape();
-    }
-    else
-    {
-        b_shape.assign(max_input_rank, 1);
-    }
-
-    for (int i = 0; i < max_input_rank; i++)
+    for (int i = 0; i < (int)a_shape.size(); i++)
     {
         if (a_shape[i] != output_shape[i] && a_shape[i] == 1)
         {
@@ -164,23 +130,8 @@
     reshaper.fill(1);
     TIn ia, ib;
 
-    if (this->a_rank == this->max_input_rank)
-    {
-        ia = this->a->getTensor().broadcast(this->bcast_a);
-    }
-    else
-    {
-        ia = this->a_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_a);
-    }
-
-    if (this->b_rank == this->max_input_rank)
-    {
-        ib = this->b->getTensor().broadcast(this->bcast_b);
-    }
-    else
-    {
-        ib = this->b_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_b);
-    }
+    ia = this->a->getTensor().broadcast(this->bcast_a);
+    ib = this->b->getTensor().broadcast(this->bcast_b);
 
     this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
 
@@ -475,7 +426,7 @@
                 }
                 else
                 {
-                    result = static_cast<int64_t>(a) * b;
+                    result                = static_cast<int64_t>(a) * b;
                     int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
                     int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
                     REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
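
A note on the last hunk: int32 * int32 can overflow 32 bits, so OpMul
widens to int64 before the range check. A minimal standalone sketch of
the same pattern (hypothetical helper name, not the model's code):

    #include <cstdint>
    #include <limits>
    #include <stdexcept>

    // Multiply two int32 values in 64-bit arithmetic, then verify the
    // product still fits in int32 -- mirroring the REQUIRE above.
    int32_t checkedMulI32(int32_t a, int32_t b)
    {
        int64_t result = static_cast<int64_t>(a) * b;
        if (result > std::numeric_limits<int32_t>::max() ||
            result < std::numeric_limits<int32_t>::min())
            throw std::runtime_error("OpMul: result not in i32 range");
        return static_cast<int32_t>(result);
    }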