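Support 8-bit TABLE op.

Add an INT8 path to the TABLE operator: an INT8 input is looked up
directly in an INT8[256] table, while an INT16 input keeps using the
INT16[513] table with linear interpolation between adjacent entries.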

Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: If577035d71c5f9970df5b6a78640a3028c3f83c0
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 76cebeb..3379ffe 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -519,20 +519,20 @@
     return 0;
 }
 
-template <int Rank>
-OpTable<Rank>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+template <int Rank, DType InDtype>    // InDtype: input tensor dtype (INT8 and INT16 are instantiated)
+OpTable<Rank, InDtype>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
     : GraphNode(Op_TABLE, id_)
 {
     setRequiredOperands(2, 1);
     setRequiredRank(0, 6);
 }
 
-template <int Rank>
-OpTable<Rank>::~OpTable()
+template <int Rank, DType InDtype>
+OpTable<Rank, InDtype>::~OpTable()
 {}
 
-template <int Rank>
-int OpTable<Rank>::checkTensorAttributes()
+template <int Rank, DType InDtype>
+int OpTable<Rank, InDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -542,12 +542,38 @@
         return 1;
     }
 
-    if (inputs[1]->getRank() != 1 || inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16)
+    if (inputs[1]->getRank() != 1)
     {
-        FATAL_ERROR_NODE("OpTable: must have INT16 table with 513 entries");
+        printNodeValidationError("OpTable: Table must be rank 1 tensor");
         return 1;
     }
 
+    // The required table shape/dtype depends on the input dtype:
+    //   INT8 input  -> INT8[256] table, indexed directly in eval()
+    //   INT16 input -> INT16[513] table, interpolated in eval()
+    if (inputs[0]->getDtype() == DType_INT8)
+    {
+        if (inputs[1]->getElementCount() != 256 || inputs[1]->getDtype() != DType_INT8)
+        {
+            printNodeValidationError("OpTable: Table must be INT8[256] if input is INT8");
+            return 1;
+        }
+    }
+    else if (inputs[0]->getDtype() == DType_INT16)
+    {
+        if (inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16)
+        {
+            printNodeValidationError("OpTable: Table must be INT16[513] if input is INT16");
+            return 1;
+        }
+    }
+    else
+    {
+        // Reject any other input dtype rather than letting it skip table validation.
+        printNodeValidationError("OpTable: Input must be INT8 or INT16");
+        return 1;
+    }
+
     in    = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]);
     out   = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
@@ -557,25 +574,51 @@
     return 0;
 }
 
-template <int Rank>
-int OpTable<Rank>::eval()
+template <int Rank, DType InDtype>
+int OpTable<Rank, InDtype>::eval()
 {
-    this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
-        // 1. make sure input is int16 range
-        int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
+    // InDtype is fixed per template instantiation, so only one case is ever taken.
+    switch (InDtype)
+    {
+        case DType_INT8:
+            // Direct lookup with no interpolation; table entry 0 corresponds to QInMin.
+            this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
+                // 1. clamp input to [QInMin, QInMax]
+                int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
+
+                // 2. offset so that QInMin maps to index 0, then read that entry
+                int32_t index           = input_truncated - QInMin;
+                int32_t value           = this->table->getTensor()(index);
 
-        // 2. calculate index and interpolation fraction
-        int32_t index = (input_truncated >> 7) + (1 << (IntegerBits - 1));
-        index         = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1);    // 9-bit index
-        int32_t frac  = (input_truncated)&0x7F;                                                 // 7-bit fraction
+                return value;
+            });
+            break;
+        case DType_INT16:
+            // Linear interpolation between adjacent table entries.
+            this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
+                // 1. make sure input is int16 range
+                int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
 
-        // 3. interpolate, generate 16.7 (23-bit) output
-        int32_t base  = this->table->getTensor()(index);
-        int32_t next  = this->table->getTensor()(index + 1);
-        int32_t value = (base << 7) + (next - base) * frac;
+                // 2. calculate index and interpolation fraction: the low FractionBits
+                //    bits are the fraction, the remaining high bits select the entry
+                int32_t index = (input_truncated >> FractionBits) + (1 << (IntegerBits - 1));
+                index         = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1);    // 9-bit index
+                int32_t frac  = input_truncated & ((1 << FractionBits) - 1);                          // FractionBits-bit fraction
 
-        return value;
-    });
+                // 3. interpolate, generate 16.7 (23-bit) output.
+                //    Scale base by multiplication rather than `base << FractionBits`:
+                //    INT16 table entries may be negative, and left-shifting a negative
+                //    signed value is undefined behavior before C++20.
+                int32_t base  = this->table->getTensor()(index);
+                int32_t next  = this->table->getTensor()(index + 1);
+                int32_t value = base * (1 << FractionBits) + (next - base) * frac;
+
+                return value;
+            });
+            break;
+        default:
+            FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+    }
 
     return GraphNode::eval();
 }
@@ -632,7 +665,8 @@
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
 
-DEF_INSTANTIATE_ONE_RANK_0_6(OpTable);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);     // INT8 input, INT8[256] table
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);    // INT16 input, INT16[513] table
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);