Fix broadcast bug
- A test like [1] + [2] = [1] should be treated as an invalid test
- Modify matchRankShape() so that it allows size 1 only on the source tensor, not on the target tensor
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I6bbb6a63dc1143712e7eef736a991cac419b009e
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 5536583..d857dc8 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -136,9 +136,11 @@
if (shape[i] != ref.shape[i])
{
if (!broadcastOk ||
- // For broadcasts, at least one operand must have size 1
- // if they don't both match
- (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1)))
+ // For broadcasts, the order of *this and ref matters.
+ // *this should be the source tensor.
+ // ref should be the target tensor. In most cases, ref is expected to be the output tensor.
+ // this->shape must have size 1 if they don't match
+ (broadcastOk && (shape[i] != 1)))
{
return 1;
}
@@ -158,9 +160,11 @@
if (shape[i] != ref.shape[i])
{
if (!broadcastOk ||
- // For broadcasts, at least one operand must have size 1
- // if they don't both match
- (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1)))
+ // For broadcasts, the order of *this and ref matters.
+ // *this should be the source tensor.
+ // ref should be the target tensor. In most cases, ref is expected to be the output tensor.
+ // this->shape must have size 1 if they don't match
+ (broadcastOk && (shape[i] != 1)))
{
return 1;
}