Support 8-bit TABLE op.

Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: If577035d71c5f9970df5b6a78640a3028c3f83c0
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index eee5464..bf80859 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -33,15 +33,6 @@
 #define DEF_INSTANTIATE_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2)                                            \
     template class TosaReference::OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>;
 
-#define DEF_INSTANTIATE_ONE_RANK_0_6(OP)                                                                               \
-    template class TosaReference::OP<0>;                                                                               \
-    template class TosaReference::OP<1>;                                                                               \
-    template class TosaReference::OP<2>;                                                                               \
-    template class TosaReference::OP<3>;                                                                               \
-    template class TosaReference::OP<4>;                                                                               \
-    template class TosaReference::OP<5>;                                                                               \
-    template class TosaReference::OP<6>;
-
 #define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) template class TosaReference::OP<DType_##DTYPE>;
 
 #define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2) template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2>;
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 76cebeb..3379ffe 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -519,20 +519,20 @@
     return 0;
 }
 
-template <int Rank>
-OpTable<Rank>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+template <int Rank, DType InDtype>
+OpTable<Rank, InDtype>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
     : GraphNode(Op_TABLE, id_)
 {
     setRequiredOperands(2, 1);
     setRequiredRank(0, 6);
 }
 
-template <int Rank>
-OpTable<Rank>::~OpTable()
+template <int Rank, DType InDtype>
+OpTable<Rank, InDtype>::~OpTable()
 {}
 
-template <int Rank>
-int OpTable<Rank>::checkTensorAttributes()
+template <int Rank, DType InDtype>
+int OpTable<Rank, InDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -542,12 +542,29 @@
         return 1;
     }
 
-    if (inputs[1]->getRank() != 1 || inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16)
+    if (inputs[1]->getRank() != 1)
     {
-        FATAL_ERROR_NODE("OpTable: must have INT16 table with 513 entries");
+        printNodeValidationError("OpTable: Table must be rank 1 tensor");
         return 1;
     }
 
+    if (inputs[0]->getDtype() == DType_INT8)
+    {
+        if (inputs[1]->getElementCount() != 256 || inputs[1]->getDtype() != DType_INT8)
+        {
+            printNodeValidationError("OpTable: Table must be INT8[256] if input is INT8");
+            return 1;
+        }
+    }
+    else if (inputs[0]->getDtype() == DType_INT16)
+    {
+        if (inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16)
+        {
+            printNodeValidationError("OpTable: Table must be INT16[513] if input is INT16");
+            return 1;
+        }
+    }
+
     in    = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]);
     out   = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
@@ -557,25 +574,41 @@
     return 0;
 }
 
-template <int Rank>
-int OpTable<Rank>::eval()
+template <int Rank, DType InDtype>
+int OpTable<Rank, InDtype>::eval()
 {
-    this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
-        // 1. make sure input is int16 range
-        int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
+    switch (InDtype)
+    {
+        case DType_INT8:
+            this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
+                int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
+                int32_t index           = input_truncated - QInMin;
+                int32_t value           = this->table->getTensor()(index);
 
-        // 2. calculate index and interpolation fraction
-        int32_t index = (input_truncated >> 7) + (1 << (IntegerBits - 1));
-        index         = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1);    // 9-bit index
-        int32_t frac  = (input_truncated)&0x7F;                                                 // 7-bit fraction
+                return value;
+            });
+            break;
+        case DType_INT16:
+            this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
+                // 1. make sure input is int16 range
+                int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
 
-        // 3. interpolate, generate 16.7 (23-bit) output
-        int32_t base  = this->table->getTensor()(index);
-        int32_t next  = this->table->getTensor()(index + 1);
-        int32_t value = (base << 7) + (next - base) * frac;
+                // 2. calculate index and interpolation fraction
+                int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1));
+                index         = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1);    // 9-bit index
+                int32_t frac  = (input_truncated)&0x7F;    // 7-bit fraction
 
-        return value;
-    });
+                // 3. interpolate, generate 16.7 (23-bit) output
+                int32_t base  = this->table->getTensor()(index);
+                int32_t next  = this->table->getTensor()(index + 1);
+                int32_t value = (base << 7) + (next - base) * frac;
+
+                return value;
+            });
+            break;
+        default:
+            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+    }
 
     return GraphNode::eval();
 }
@@ -632,7 +665,8 @@
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
 
-DEF_INSTANTIATE_ONE_RANK_0_6(OpTable);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 6b9c98d..a5b1059 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -176,7 +176,7 @@
     TosaMulAttribute* attribute;
 };
 
-template <int Rank>
+template <int Rank, DType InDtype>
 class OpTable : public GraphNode
 {
 public:
@@ -186,9 +186,8 @@
     virtual int checkTensorAttributes();
     virtual int eval();
 
-    static constexpr DType InDtype           = DType_INT16;
-    static constexpr DType TableDtype        = DType_INT16;
-    static constexpr DType OutDtype          = DType_INT32;
+    static constexpr DType TableDtype        = (InDtype == DType_INT8) ? DType_INT8 : DType_INT16;
+    static constexpr DType OutDtype          = (InDtype == DType_INT8) ? DType_INT8 : DType_INT32;
     using InEigenType                        = typename GetEigenType<InDtype>::type;
     using TableEigenType                     = typename GetEigenType<TableDtype>::type;
     using OutEigenType                       = typename GetEigenType<OutDtype>::type;
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 440d624..726ab7c 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -178,7 +178,8 @@
             DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
             break;
         case Op_TABLE:
-            DEF_FACTORY_ONE_RANK_0_6(OpTable);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
+            DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
             break;
 
         // ewise_unary