Check valid broadcastable shape for binary and ternary ops

Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I9ed3d8971a133b4cbb2cf7d827f4e69d55dee246
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index b199f69..287ad92 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -60,10 +60,18 @@
         return 1;
     }
 
-    if (inputs[0]->matchRank(*outputs[0]))
+    if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */))
     {
         std::string err =
-            "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match";
+            "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " lhs input and output rank/shape must match";
+        printNodeValidationError(err.c_str());
+        return 1;
+    }
+
+    if (inputs[1]->matchRankShape(*outputs[0], true /* broadcastOk */))
+    {
+        std::string err =
+            "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " rhs input and output rank/shape must match";
         printNodeValidationError(err.c_str());
         return 1;
     }
@@ -82,31 +90,14 @@
 template <int Rank, DType InDtype, DType OutDtype>
 int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
 {
-    auto output_shape = result->getTensor().dimensions();
+    const std::vector<int>& a_shape      = a->getShape();
+    const std::vector<int>& b_shape      = b->getShape();
+    const std::vector<int>& output_shape = result->getShape();
 
-    std::vector<int> a_shape, b_shape;
-
-    a_shape = a->getShape();
-    b_shape = b->getShape();
-
-    for (int i = 0; i < (int)a_shape.size(); i++)
+    for (int i = 0; i < Rank; i++)
     {
-        if (a_shape[i] != output_shape[i] && a_shape[i] == 1)
-        {
-            bcast_a[i] = output_shape[i];
-        }
-        else
-        {
-            bcast_a[i] = 1;
-        }
-        if (b_shape[i] != output_shape[i] && b_shape[i] == 1)
-        {
-            bcast_b[i] = output_shape[i];
-        }
-        else
-        {
-            bcast_b[i] = 1;
-        }
+        bcast_a[i] = (a_shape[i] != output_shape[i] && a_shape[i] == 1) ? output_shape[i] : 1;
+        bcast_b[i] = (b_shape[i] != output_shape[i] && b_shape[i] == 1) ? output_shape[i] : 1;
     }
 
     return 0;
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
index 64c4412..c265077 100644
--- a/reference_model/src/ops/ewise_ternary.cc
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -47,10 +47,11 @@
     }
 
     // output and input must be the same types
-    if (inputs[0]->matchRank(*outputs[0]) || inputs[1]->matchRankType(*outputs[0]) ||
-        inputs[2]->matchRankType(*outputs[0]))
+    if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */) ||
+        inputs[1]->matchRankTypeShape(*outputs[0], true /* broadcastOk */) ||
+        inputs[2]->matchRankTypeShape(*outputs[0], true /* broadcastOk */))
     {
-        printNodeValidationError("Failure to match input and output rank and type");
+        printNodeValidationError("Failure to match input and output rank/type/shape");
         return 1;
     }
 
@@ -71,19 +72,16 @@
 template <int Rank, DType Dtype>
 int OpSelect<Rank, Dtype>::broadcast()
 {
-    std::vector<int> cond_shape = this->cond->getShape();
-    std::vector<int> then_shape = this->then_val->getShape();
-    std::vector<int> else_shape = this->else_val->getShape();
-    std::vector<int> out_shape  = this->out->getShape();
+    const std::vector<int>& cond_shape   = this->cond->getShape();
+    const std::vector<int>& then_shape   = this->then_val->getShape();
+    const std::vector<int>& else_shape   = this->else_val->getShape();
+    const std::vector<int>& output_shape = this->out->getShape();
 
     for (int i = 0; i < Rank; i++)
     {
-        this->bcast_cond[i] = (cond_shape[i] == 1) ? std::max(then_shape[i], else_shape[i]) : 1;
-        this->bcast_then[i] = (then_shape[i] == 1) ? std::max(cond_shape[i], else_shape[i]) : 1;
-        this->bcast_else[i] = (else_shape[i] == 1) ? std::max(then_shape[i], cond_shape[i]) : 1;
-        ERROR_IF((this->bcast_cond[i] * cond_shape[i]) != out_shape[i], "SELECT broadcast invariant failed");
-        ERROR_IF((this->bcast_then[i] * then_shape[i]) != out_shape[i], "SELECT broadcast invariant failed");
-        ERROR_IF((this->bcast_else[i] * else_shape[i]) != out_shape[i], "SELECT broadcast invariant failed");
+        this->bcast_cond[i] = (cond_shape[i] != output_shape[i] && cond_shape[i] == 1) ? output_shape[i] : 1;
+        this->bcast_then[i] = (then_shape[i] != output_shape[i] && then_shape[i] == 1) ? output_shape[i] : 1;
+        this->bcast_else[i] = (else_shape[i] != output_shape[i] && else_shape[i] == 1) ? output_shape[i] : 1;
     }
 
     return 0;
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 3fa23f9..5536583 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -148,6 +148,28 @@
         return 0;
     }
 
+    const int matchRankShape(const Tensor& ref, const bool broadcastOk = false) const
+    {
+        if (matchRank(ref))
+            return 1;
+
+        for (size_t i = 0; i < shape.size(); i++)
+        {
+            if (shape[i] != ref.shape[i])
+            {
+                if (!broadcastOk ||
+                    // For broadcasts, at least one operand must have size 1
+                    // if they don't both match
+                    (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1)))
+                {
+                    return 1;
+                }
+            }
+        }
+
+        return 0;
+    }
+
     // Sometimes we might want to match several semi-compatible types,
     // so just check rank and size here
     const int matchRankSize(const Tensor& ref) const