Initial checkin of TOSA reference_model and tests

Change-Id: I2f8e7fa63e2ae40203e57d2cc8814bde3b312cb6
Signed-off-by: Eric Kunze <eric.kunze@arm.com>
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
new file mode 100644
index 0000000..4d4f8b9
--- /dev/null
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -0,0 +1,586 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//         http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+
+#include "ewise_binary.h"
+#include "arith_util.h"
+#include "quant_util.h"
+#include "template_types.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+// Constructs the shared base of all element-wise binary nodes; tensor handles are bound later in checkTensorAttributes().
+template <int Rank, DType InDtype, DType OutDtype>
+BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(const Op& op_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+    : GraphNode(op_, id_)
+{
+    setRequiredOperands(2, 1);
+    setRequiredRank(0, 6);
+    // qinfo_ is accepted but not used here. Ranks and tensor handles start out unset.
+    max_input_rank = b_rank = a_rank = -1;
+    b_rank0 = a_rank0 = nullptr;
+    b = a = nullptr;
+    result = nullptr;
+    fcn = [](InEigenType, InEigenType) -> OutEigenType { return OutEigenType(); };    // placeholder functor
+}
+
+// Destructor with no explicit teardown: nothing is released by hand in this class.
+template <int Rank, DType InDtype, DType OutDtype>
+BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase() = default;
+
+// Validates operand count, ranks and input dtypes, then binds the typed tensor handles.
+// An input whose rank is below max_input_rank (i.e. rank 0 next to a higher-rank operand) is
+// bound through a_rank0 / b_rank0 instead of a / b so that eval() can reshape and broadcast it.
+// Returns 0 on success and 1 after a validation failure.
+template <int Rank, DType InDtype, DType OutDtype>
+int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
+{
+    if (validateRequiredOperands())
+        return 1;
+    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+    {
+        return 1;
+    }
+
+    a_rank = inputs[0]->getRank();
+    b_rank = inputs[1]->getRank();
+    if (a_rank != 0 && b_rank != 0 && a_rank != b_rank)
+    {
+        printNodeValidationError("Binary operator input ranks must match");
+        return 1;
+    }
+    max_input_rank = a_rank >= b_rank ? a_rank : b_rank;
+
+    // A & B must be the same type (a non-zero matchType() result is treated as a mismatch)
+    if (inputs[0]->matchType(*inputs[1]))
+    {
+        printNodeValidationError("Binary operator input types must match");
+        return 1;
+    }
+
+    // Result's geometry must match, but the type may be wider
+    if (outputs[0]->getRank() != max_input_rank)
+    {
+        printNodeValidationError("Binary operator input and output geometry must match");
+        return 1;
+    }
+
+    if (a_rank == max_input_rank)
+    {
+        a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    }
+    else
+    {
+        a_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[0]);
+    }
+
+    if (b_rank == max_input_rank)
+    {
+        b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+    }
+    else
+    {
+        b_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[1]);
+    }
+    result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+    // Either a or b can be rank 0.
+    // a_rank0 and b_rank0 can't be valid at the same time:
+    // if a and b are both rank 0, they are evaluated as 'a' and 'b' instead of 'a_rank0' and 'b_rank0'.
+    ASSERT_MEM((a || a_rank0) && (b || b_rank0) && !(a_rank0 && b_rank0) && result);
+    return 0;
+}
+
+// Derives the per-dimension broadcast factors bcast_a / bcast_b from the output tensor's dimensions.
+// For each dimension, an operand extent of 1 that differs from the output extent gets the output
+// extent as its factor; every other dimension gets factor 1. No check is made here that a
+// non-unit extent actually equals the output extent. Always returns 0.
+template <int Rank, DType InDtype, DType OutDtype>
+int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
+{
+    auto out_dims = result->getTensor().dimensions();
+
+    // An operand bound at lower rank (a_rank0 / b_rank0 case) is treated as an all-ones shape.
+    std::vector<int> shape_a, shape_b;
+    if (a_rank != max_input_rank)
+    {
+        shape_a.assign(max_input_rank, 1);
+    }
+    else
+    {
+        shape_a = a->getShape();
+    }
+    if (b_rank != max_input_rank)
+    {
+        shape_b.assign(max_input_rank, 1);
+    }
+    else
+    {
+        shape_b = b->getShape();
+    }
+
+    for (int dim = 0; dim < max_input_rank; dim++)
+    {
+        // Only a unit extent that differs from the output extent is stretched.
+        const bool stretch_a = (shape_a[dim] == 1) && (shape_a[dim] != out_dims[dim]);
+        const bool stretch_b = (shape_b[dim] == 1) && (shape_b[dim] != out_dims[dim]);
+
+        if (stretch_a)
+            bcast_a[dim] = out_dims[dim];
+        else
+            bcast_a[dim] = 1;
+
+        if (stretch_b)
+            bcast_b[dim] = out_dims[dim];
+        else
+            bcast_b[dim] = 1;
+    }
+
+    return 0;
+}
+
+// Evaluates the node: both operands are expanded to the output shape (a lower-rank operand is
+// first reshaped to an all-ones shape of rank Rank), then fcn is applied element-wise.
+// Returns whatever GraphNode::eval() returns; the result of broadcast() is not inspected.
+template <int Rank, DType InDtype, DType OutDtype>
+int BinaryNode<Rank, InDtype, OutDtype>::eval()
+{
+    this->broadcast();
+
+    // Target shape used to lift a rank-0 tensor to rank Rank before broadcasting.
+    Eigen::array<int, Rank> unit_dims;
+    unit_dims.fill(1);
+
+    const bool a_full_rank = (this->a_rank == this->max_input_rank);
+    const bool b_full_rank = (this->b_rank == this->max_input_rank);
+
+    TIn lhs, rhs;
+    if (!a_full_rank)
+        lhs = this->a_rank0->getTensor().reshape(unit_dims).broadcast(this->bcast_a);
+    else
+        lhs = this->a->getTensor().broadcast(this->bcast_a);
+
+    if (!b_full_rank)
+        rhs = this->b_rank0->getTensor().reshape(unit_dims).broadcast(this->bcast_b);
+    else
+        rhs = this->b->getTensor().broadcast(this->bcast_b);
+
+    // Apply the registered functor pairwise and store the outcome in the output tensor.
+    this->result->getTensor() = lhs.binaryExpr(rhs, this->fcn);
+
+    return GraphNode::eval();
+}
+
+// Rank-0 partial specialization (still required, or Eigen throws a static assertion); no broadcasting is done here.
+template <DType InDtype, DType OutDtype>
+int BinaryNode<0, InDtype, OutDtype>::eval()
+{
+    const auto& rhs = this->b->getTensor();
+    this->result->getTensor() = this->a->getTensor().binaryExpr(rhs, this->fcn);
+    return GraphNode::eval();
+}
+
+// Registers element-wise addition (lhs + rhs) as this node's functor for FLOAT and INT32 inputs.
+// Any other InDtype is reported through FATAL_ERROR_NODE. Returns 0.
+template <int Rank, DType Dtype>
+int OpAdd<Rank, Dtype>::register_fcn()
+{
+    if (InDtype == DType_FLOAT || InDtype == DType_INT32)
+    {
+        this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType { return lhs + rhs; };
+    }
+    else
+    {
+        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+    }
+    return 0;
+}
+
+// Registers an arithmetic (sign-filling) right shift for INT8 / INT16 / INT32 element types.
+// num_bits is the element width. The functor shifts ONES_MASK(shift) up to the top of that width
+// (presumably ONES_MASK(n) yields n low one-bits -- confirm in its definition) and ORs the mask
+// in when the sign bit of `value` is set, or clears those bits otherwise.
+// NOTE(review): the shift amount is not range-checked here; confirm callers keep it below num_bits.
+template <int Rank, DType Dtype>
+int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
+{
+    int32_t num_bits = 0;
+    if (Dtype == DType_INT8)
+        num_bits = 8;
+    else if (Dtype == DType_INT16)
+        num_bits = 16;
+    else if (Dtype == DType_INT32)
+        num_bits = 32;
+    else
+    {
+        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+    }
+
+    this->fcn = [num_bits](InEigenType value, InEigenType shift) -> OutEigenType {
+        uint32_t sign_bit  = value & (1 << (num_bits - 1));
+        uint32_t high_fill = ONES_MASK(shift) << (num_bits - shift);
+        if (!sign_bit)
+            return (~high_fill) & (value >> shift);
+        return high_fill | (value >> shift);
+    };
+
+    return 0;
+}
+
+// Registers bitwise AND (lhs & rhs) as this node's functor.
+// Accepted element types: AINT8, INT16 and INT32; any other Dtype is reported through
+// FATAL_ERROR_NODE. Returns 0.
+template <int Rank, DType Dtype>
+int OpBitwiseAnd<Rank, Dtype>::register_fcn()
+{
+    if (Dtype == DType_AINT8 || Dtype == DType_INT16 || Dtype == DType_INT32)
+    {
+        this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType { return lhs & rhs; };
+    }
+    else
+    {
+        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+    }
+    return 0;
+}
+
+// Registers bitwise OR (lhs | rhs) as this node's functor.
+// Accepted element types: AINT8, INT16 and INT32; any other Dtype is reported through
+// FATAL_ERROR_NODE. Returns 0.
+template <int Rank, DType Dtype>
+int OpBitwiseOr<Rank, Dtype>::register_fcn()
+{
+    if (Dtype == DType_AINT8 || Dtype == DType_INT16 || Dtype == DType_INT32)
+    {
+        this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType { return lhs | rhs; };
+    }
+    else
+    {
+        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+    }
+    return 0;
+}
+
+// Registers bitwise XOR (lhs ^ rhs) as this node's functor.
+// Accepted element types: AINT8, INT16 and INT32; any other Dtype is reported through
+// FATAL_ERROR_NODE. Returns 0.
+template <int Rank, DType Dtype>
+int OpBitwiseXor<Rank, Dtype>::register_fcn()
+{
+    if (Dtype == DType_AINT8 || Dtype == DType_INT16 || Dtype == DType_INT32)
+    {
+        this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType { return lhs ^ rhs; };
+    }
+    else
+    {
+        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+    }
+    return 0;
+}
+
+// Registers logical AND (lhs && rhs) for BOOL tensors; other Dtypes go to FATAL_ERROR_NODE. Returns 0.
+template <int Rank, DType Dtype>
+int OpLogicalAnd<Rank, Dtype>::register_fcn()
+{
+    if (Dtype == DType_BOOL)
+    {
+        this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType { return lhs && rhs; };
+    }
+    else
+    {
+        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+    }
+    return 0;
+}
+
+// Registers the left shift (value << shift) for INT8 / INT16 / INT32; other Dtypes go to FATAL_ERROR_NODE.
+// NOTE(review): neither the shift amount nor the sign of `value` is checked here -- presumably
+// callers keep both in the range where `<<` is well defined; confirm. Returns 0.
+template <int Rank, DType Dtype>
+int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
+{
+    if (Dtype == DType_INT8 || Dtype == DType_INT16 || Dtype == DType_INT32)
+    {
+        this->fcn = [](InEigenType value, InEigenType shift) -> OutEigenType { return value << shift; };
+    }
+    else
+    {
+        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+    }
+    return 0;
+}
+
+// Registers a logical (zero-filling) right shift for INT8 / INT16 / INT32 element types.
+// The functor computes (value >> shift) and ANDs it with ONES_MASK(num_bits) >> shift, where
+// num_bits is the element width (presumably ONES_MASK(n) yields n low one-bits -- confirm).
+// NOTE(review): the shift amount is not range-checked here. Returns 0.
+template <int Rank, DType Dtype>
+int OpLogicalRightShift<Rank, Dtype>::register_fcn()
+{
+    int32_t num_bits = 0;
+    if (Dtype == DType_INT8)
+        num_bits = 8;
+    else if (Dtype == DType_INT16)
+        num_bits = 16;
+    else if (Dtype == DType_INT32)
+        num_bits = 32;
+    else
+    {
+        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+    }
+
+    this->fcn = [num_bits](InEigenType value, InEigenType shift) -> OutEigenType {
+        uint32_t keep_mask = ONES_MASK(num_bits) >> shift;
+        return (value >> shift) & keep_mask;
+    };
+
+    return 0;
+}
+
+// Registers logical OR (lhs || rhs) for BOOL tensors; other Dtypes go to FATAL_ERROR_NODE. Returns 0.
+template <int Rank, DType Dtype>
+int OpLogicalOr<Rank, Dtype>::register_fcn()
+{
+    if (Dtype == DType_BOOL)
+    {
+        this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType { return lhs || rhs; };
+    }
+    else
+    {
+        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+    }
+    return 0;
+}
+
+// Registers logical XOR (lhs ^ rhs) for BOOL tensors; other Dtypes go to FATAL_ERROR_NODE. Returns 0.
+template <int Rank, DType Dtype>
+int OpLogicalXor<Rank, Dtype>::register_fcn()
+{
+    if (Dtype == DType_BOOL)
+    {
+        this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType { return lhs ^ rhs; };
+    }
+    else
+    {
+        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+    }
+    return 0;
+}
+
+// Registers element-wise maximum for FLOAT and INT32: yields lhs when lhs > rhs, otherwise rhs
+// (so rhs is also the result whenever the comparison is false). Other Dtypes: FATAL_ERROR_NODE. Returns 0.
+template <int Rank, DType Dtype>
+int OpMaximum<Rank, Dtype>::register_fcn()
+{
+    if (Dtype == DType_FLOAT || Dtype == DType_INT32)
+    {
+        this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType { return lhs > rhs ? lhs : rhs; };
+    }
+    else
+    {
+        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+    }
+    return 0;
+}
+
+// Registers element-wise minimum for FLOAT and INT32: yields lhs when lhs < rhs, otherwise rhs
+// (so rhs is also the result whenever the comparison is false). Other Dtypes: FATAL_ERROR_NODE. Returns 0.
+template <int Rank, DType Dtype>
+int OpMinimum<Rank, Dtype>::register_fcn()
+{
+    if (Dtype == DType_FLOAT || Dtype == DType_INT32)
+    {
+        this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType { return lhs < rhs ? lhs : rhs; };
+    }
+    else
+    {
+        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+    }
+    return 0;
+}
+
+// Registers element-wise multiplication as this node's functor.
+//  - FLOAT / INT32: plain lhs * rhs.
+//  - INT8 / INT16: both operands are cast to OutEigenType, multiplied, then clamped to [QMin, QMax].
+// Any other InDtype is reported through FATAL_ERROR_NODE. Returns 0.
+template <int Rank, DType InDtype, DType OutDtype>
+int OpMul<Rank, InDtype, OutDtype>::register_fcn()
+{
+    if (InDtype == DType_FLOAT || InDtype == DType_INT32)
+    {
+        this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType { return lhs * rhs; };
+    }
+    else if (InDtype == DType_INT8 || InDtype == DType_INT16)
+    {
+        this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
+            OutEigenType product = (OutEigenType)lhs * (OutEigenType)rhs;
+            return std::min<OutEigenType>(QMax, std::max<OutEigenType>(product, QMin));
+        };
+    }
+    else
+    {
+        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+    }
+
+    return 0;
+}
+
+// Registers powf(base, exponent) for FLOAT tensors; other Dtypes go to FATAL_ERROR_NODE. Returns 0.
+template <int Rank, DType Dtype>
+int OpPow<Rank, Dtype>::register_fcn()
+{
+    if (Dtype == DType_FLOAT)
+    {
+        this->fcn = [](InEigenType base, InEigenType exponent) -> OutEigenType { return powf(base, exponent); };
+    }
+    else
+    {
+        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+    }
+    return 0;
+}
+
+// Registers element-wise subtraction (lhs - rhs) as this node's functor for FLOAT and INT32 inputs.
+// Any other InDtype is reported through FATAL_ERROR_NODE. Returns 0.
+template <int Rank, DType Dtype>
+int OpSub<Rank, Dtype>::register_fcn()
+{
+    if (InDtype == DType_FLOAT || InDtype == DType_INT32)
+    {
+        this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType { return lhs - rhs; };
+    }
+    else
+    {
+        FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+    }
+    return 0;
+}
+
+template <int Rank>
+OpTable<Rank>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+    : GraphNode(Op_TABLE, id_)    // attribute_ and qinfo_ are accepted but not used in this constructor
+{
+    setRequiredOperands(2, 1);    // checkTensorAttributes() reads inputs[0], inputs[1] (the table) and outputs[0]
+    setRequiredRank(0, 6);        // presumably the rank range checked by validateRequiredRank() -- confirm
+}
+
+// Destructor with no explicit teardown: nothing is released by hand in this class.
+template <int Rank>
+OpTable<Rank>::~OpTable() = default;
+
+// Validates the TABLE operands and binds the typed handles (in, table, out). inputs[1] must be a
+// rank-1 INT16 tensor with exactly 513 elements. Returns 0 on success, 1 when a validation step fails.
+template <int Rank>
+int OpTable<Rank>::checkTensorAttributes()
+{
+    if (validateRequiredOperands())
+    {
+        return 1;
+    }
+    if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+        return 1;
+    const bool table_ok = inputs[1]->getRank() == 1 && inputs[1]->getElementCount() == 513 &&
+                          inputs[1]->getDtype() == DType_INT16;
+    if (!table_ok)
+    {
+        FATAL_ERROR_NODE("OpTable: must have INT16 table with 513 entries");
+        return 1;
+    }
+
+    in    = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]);
+    out   = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+    ASSERT_MEM(in && table && out);
+    return 0;
+}
+
+// Maps every input element through the INT16 lookup table with linear interpolation: the input
+// is clamped to [QInMin, QInMax], its bits above the low 7 select a table entry, and the low
+// 7 bits interpolate towards the next entry. Returns whatever GraphNode::eval() returns.
+template <int Rank>
+int OpTable<Rank>::eval()
+{
+    this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType sample) -> OutEigenType {
+        // 1. make sure input is int16 range
+        int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(sample, QInMin), QInMax);
+        // 2. calculate index and interpolation fraction
+        int32_t index = (input_truncated >> 7) + (1 << (IntegerBits - 1));
+        index         = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1);    // 9-bit index
+        int32_t frac  = (input_truncated)&0x7F;                                                 // 7-bit fraction
+        // 3. interpolate, generate 16.7 (23-bit) output
+        int32_t base  = this->table->getTensor()(index);
+        int32_t next  = this->table->getTensor()(index + 1);
+        // base may be negative: scale by multiplication, since left-shifting a negative value is undefined before C++20
+        int32_t value = base * (1 << 7) + (next - base) * frac;
+        return value;
+    });
+    return GraphNode::eval();
+}
+
+// Explicit template instantiations: each DEF_INSTANTIATE_* macro (defined elsewhere) instantiates the named operator for the listed element type(s).
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
+
+DEF_INSTANTIATE_ONE_RANK_0_6(OpTable);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);    // BOOL-output BinaryNode; presumably used by operators defined in another file -- confirm
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);