Change the shift of mul to tensor type

Right shift result on i32_t data type only, i.e. other data types
don't carry the shift operand.

In the spec, the shift type is a tensor in MT profile and is an
attribute in BI/MI profiles. Currently we treat the shift as tensor
throughout.

In implementation, since `ternaryExpr` is not implemented in Eigen,
decompose the original calculation into multiply and shift operations
separately, and execute them via `binaryExpr`.

Change-Id: I349f4969545134ac5f13bc83032cd75cca3e7ba0
Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index b513f9a..ed176f3 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -441,9 +441,100 @@
 }
 
 template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
+int OpMul<Rank, InDtype, OutDtype>::eval()
+{
+    // All cases except in_out_t == int32_t go to the general binary op workflow.
+    if constexpr (InDtype != TOSA_REF_TYPE_INT32)
+    {
+        return BinaryNode<Rank, InDtype, OutDtype>::eval();
+    }
+    else
+    {
+        std::vector<int> calculated_shape;
+        this->broadcast(calculated_shape);
+
+        auto result_shape = this->result->getShape();
+        ERROR_IF(calculated_shape != result_shape,
+                 "Broadcast_shape failure, calculated_shape and result_shape don't match");
+
+        TIn ia = this->a->getTensor().broadcast(this->bcast_a);
+        TIn ib = this->b->getTensor().broadcast(this->bcast_b);
+
+        using TInt64      = Eigen::Tensor<int64_t, Rank>;
+        TInt64 tmp_result = ia.binaryExpr(ib, this->mul_fcn);
+
+        // Retrieve `shift` value and construct an Eigen tensor instance for it.
+        s = dynamic_cast<TosaReference::TensorTemplate<TShift>*>(this->inputs[2]);
+        ASSERT_MEM(s);
+
+        int shift = s->getTensor()(0);
+        TIn is(ia);
+        is.setConstant(shift);
+
+        TOut result               = tmp_result.binaryExpr(is, this->shr_fcn);
+        this->result->getTensor() = result;
+
+        return GraphNode::eval();
+    }
+}
+
+// Eigen operators require tensor operands with NumDims > 0; partially
+// specialize this as we did for the base class.
+template <>
+int OpMul<0, TOSA_REF_TYPE_INT32, TOSA_REF_TYPE_INT32>::eval()
+{
+    Eigen::Tensor<int64_t, 0> tmp_result = this->a->getTensor().binaryExpr(this->b->getTensor(), this->mul_fcn);
+
+    // Retrieve `shift` value.
+    s = dynamic_cast<TosaReference::TensorTemplate<TShift>*>(this->inputs[2]);
+    ASSERT_MEM(s);
+
+    Eigen::Tensor<int64_t, 0> shift;
+    shift.setConstant(s->getTensor()(0));
+
+    this->result->getTensor() = tmp_result.binaryExpr(shift, this->shr_fcn);
+
+    return GraphNode::eval();
+}
+
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 int OpMul<Rank, InDtype, OutDtype>::register_fcn()
 {
-    int32_t shift = attribute->shift();
+    // Register evaluation function for in_out_t == int32_t case first as it supports shift
+    // right to int32_t output.
+    if constexpr (InDtype == TOSA_REF_TYPE_INT32)
+    {
+        // Perform multiplication on int32_t inputs to produce an int64_t result.
+        this->mul_fcn = [](InEigenType a, InEigenType b) -> int64_t {
+            int64_t result = static_cast<int64_t>(a) * static_cast<int64_t>(b);
+            return result;
+        };
+
+        // Convert data from int64_t to int32_t.
+        this->shr_fcn = [this](int64_t a, InEigenType shift) -> OutEigenType {
+            int64_t result;
+            if (shift > 0)
+            {
+                int64_t round = INT64_C(1) << (shift - 1);
+                result        = a + round;
+                result        = result >> shift;
+
+                REQUIRE(result >= QMin && result <= QMax,
+                        "OpMul: result %" PRId64 " exceeds valid range [%" PRId64 ", %" PRId64 "]", result, QMin, QMax);
+            }
+            else
+            {
+                result                = a;
+                int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
+                int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
+                REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
+                return static_cast<InEigenType>(result);
+            }
+            return static_cast<OutEigenType>(result);
+        };
+
+        return 0;
+    }
 
     switch (InDtype)
     {
@@ -455,31 +546,6 @@
         case TOSA_REF_TYPE_FP64:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
             break;
-        case TOSA_REF_TYPE_INT32:
-            this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
-                int64_t result;
-                if (shift > 0)
-                {
-                    int64_t round = INT64_C(1) << (shift - 1);
-                    result        = static_cast<int64_t>(a) * static_cast<int64_t>(b) + round;
-                    result        = result >> shift;
-
-                    REQUIRE(result >= QMin && result <= QMax,
-                            "OpMul: result %" PRId64 " exceeds valid range [%" PRId64 ", %" PRId64 "]", result, QMin,
-                            QMax);
-                }
-                else
-                {
-                    result                = static_cast<int64_t>(a) * b;
-                    int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
-                    int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
-                    REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
-                    return static_cast<InEigenType>(result);
-                }
-
-                return static_cast<OutEigenType>(result);
-            };
-            break;
         case TOSA_REF_TYPE_INT8:
         case TOSA_REF_TYPE_INT16:
             this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType {
@@ -497,13 +563,6 @@
     return 0;
 }
 
-template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
-OpMul<Rank, InDtype, OutDtype>::~OpMul()
-{
-    if (attribute)
-        delete attribute;
-}
-
 template <int Rank, TOSA_REF_TYPE Dtype>
 int OpPow<Rank, Dtype>::register_fcn()
 {
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 1215c93..8d2e486 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -159,18 +159,33 @@
     OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
         : BinaryNode<Rank, InDtype, OutDtype>(sgt_, Op_MUL, id_)
     {
-        INIT_ATTRIBUTE(Mul);
+        if constexpr (InDtype == TOSA_REF_TYPE_INT32)
+        {
+            // Require `shift` operand.
+            this->setRequiredOperands(3, 1);
+        }
         register_fcn();
     }
-    virtual ~OpMul();
     static constexpr int64_t QMin = GetQMin<OutDtype>::value;
     static constexpr int64_t QMax = GetQMax<OutDtype>::value;
-    using InEigenType             = typename GetEigenType<InDtype>::type;
-    using OutEigenType            = typename GetEigenType<OutDtype>::type;
-    virtual int register_fcn();
+
+    using InEigenType    = typename GetEigenType<InDtype>::type;
+    using OutEigenType   = typename GetEigenType<OutDtype>::type;
+    using ShiftEigenType = typename GetEigenType<TOSA_REF_TYPE_INT8>::type;
+
+    using TIn    = Eigen::Tensor<InEigenType, Rank>;
+    using TOut   = Eigen::Tensor<OutEigenType, Rank>;
+    using TShift = Eigen::Tensor<ShiftEigenType, 0>;
+
+    int register_fcn();
+    int eval();
+
+    // Note that INT64 is not natively supported in Dtype system.
+    std::function<int64_t(InEigenType, InEigenType)> mul_fcn;
+    std::function<OutEigenType(int64_t, InEigenType)> shr_fcn;
 
 protected:
-    TosaMulAttribute* attribute;
+    TosaReference::TensorTemplate<TShift>* s;
 };
 
 template <int Rank, TOSA_REF_TYPE InDtype>