Add support for precise mode in eager runner

Add support for Fp64 tensors in the eager runner's
helper functions, when precise mode is enabled.

Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli@arm.com>
Change-Id: Ib737c0d18fb1c7ac40ce6ea03a4fbcefae88ba5c
diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc
index 311db7c..bf23bac 100644
--- a/reference_model/src/model_runner_impl.cc
+++ b/reference_model/src/model_runner_impl.cc
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2022-2023, ARM Limited.
+// Copyright (c) 2022-2024, ARM Limited.
 //
 //    Licensed under the Apache License, Version 2.0 (the "License");
 //    you may not use this file except in compliance with the License.
@@ -166,6 +166,40 @@
     return 0;
 }
 
+int ModelRunnerImpl::setInputForPrecMode(Tensor* tensor, std::string input_name, uint8_t* raw_ptr, size_t size)
+{
+    ASSERT_MSG(tensor, "Tensor not provided!");
+    if (!g_func_config.precise_mode)
+    {
+        WARNING("Cannot set input tensor %s using precise mode setters when not running in precise mode!",
+                input_name.c_str());
+        return 1;
+    }
+
+    DType ser_dtype = tensor->getSerializationDtype();
+    int status;
+
+    switch (ser_dtype)
+    {
+        case DType::DType_FP16: {
+            auto typed_ptr     = reinterpret_cast<half_float::half*>(raw_ptr);
+            const int elements = size / sizeof(half_float::half);
+            status             = setInput(input_name, ArrayProxy(elements, typed_ptr));
+            break;
+        }
+        case DType::DType_FP32: {
+            auto typed_ptr     = reinterpret_cast<float*>(raw_ptr);
+            const int elements = size / sizeof(float);
+            status             = setInput(input_name, ArrayProxy(elements, typed_ptr));
+            break;
+        }
+        default:
+            status = 1;
+    }
+
+    return status;
+}
+
 int ModelRunnerImpl::setInput(std::string input_name, uint8_t* raw_ptr, size_t size)
 {
     if (_main_gt == nullptr)
@@ -197,6 +231,18 @@
             status             = setInput(input_name, ArrayProxy(elements, typed_ptr));
             break;
         }
+        case TOSA_REF_TYPE_FP64:
+            if (g_func_config.precise_mode)
+            {
+                status = setInputForPrecMode(tensor, input_name, raw_ptr, size);
+            }
+            else
+            {
+                auto typed_ptr     = reinterpret_cast<double*>(raw_ptr);
+                const int elements = size / sizeof(double);
+                status             = setInput(input_name, ArrayProxy(elements, typed_ptr));
+            }
+            break;
         case TOSA_REF_TYPE_INT16: {
             auto typed_ptr     = reinterpret_cast<int16_t*>(raw_ptr);
             const int elements = size / sizeof(int16_t);
@@ -281,6 +327,12 @@
             status             = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
             break;
         }
+        case TOSA_REF_TYPE_FP64: {
+            auto typed_ptr     = reinterpret_cast<double*>(raw_ptr);
+            const int elements = size / sizeof(double);
+            status             = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
+            break;
+        }
         case TOSA_REF_TYPE_BOOL: {
             auto typed_ptr     = reinterpret_cast<unsigned char*>(raw_ptr);
             const int elements = size / sizeof(unsigned char);
@@ -394,12 +446,14 @@
 }
 
 // Template explicit specialization
+template int ModelRunnerImpl::setInput<double>(std::string input_name, ArrayProxy<double> vals);
 template int ModelRunnerImpl::setInput<float>(std::string input_name, ArrayProxy<float> vals);
 template int ModelRunnerImpl::setInput<half_float::half>(std::string input_name, ArrayProxy<half_float::half> vals);
 template int ModelRunnerImpl::setInput<int32_t>(std::string input_name, ArrayProxy<int32_t> vals);
 template int ModelRunnerImpl::setInput<int64_t>(std::string input_name, ArrayProxy<int64_t> vals);
 template int ModelRunnerImpl::setInput<unsigned char>(std::string input_name, ArrayProxy<unsigned char> vals);
 
+template std::vector<double> ModelRunnerImpl::getOutput<double>(std::string output_name);
 template std::vector<float> ModelRunnerImpl::getOutput<float>(std::string output_name);
 template std::vector<half_float::half> ModelRunnerImpl::getOutput<half_float::half>(std::string output_name);
 template std::vector<int32_t> ModelRunnerImpl::getOutput<int32_t>(std::string output_name);
diff --git a/reference_model/src/model_runner_impl.h b/reference_model/src/model_runner_impl.h
index aed8a1e..db9755c 100644
--- a/reference_model/src/model_runner_impl.h
+++ b/reference_model/src/model_runner_impl.h
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2022-2023, ARM Limited.
+// Copyright (c) 2022-2024, ARM Limited.
 //
 //    Licensed under the Apache License, Version 2.0 (the "License");
 //    you may not use this file except in compliance with the License.
@@ -64,6 +64,7 @@
     GraphStatus initialize(TosaSerializationBasicBlock* bb, TosaSerializationHandler* serialization_handler);
     void validateTosaVersion(TosaSerializationHandler& serialization_handler);
     void checkGraphStatus(SubgraphTraverser& main_gt);
+    int setInputForPrecMode(Tensor* tensor, std::string input_name, uint8_t* raw_ptr, size_t size);
 };
 
 };    // namespace TosaReference
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 645b55f..e84507b 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
 //
 //    Licensed under the Apache License, Version 2.0 (the "License");
 //    you may not use this file except in compliance with the License.
@@ -580,6 +580,14 @@
     uint32_t elements = getElementCount();
     switch (getDtype())
     {
+        case TOSA_REF_TYPE_FP64:
+            if (!g_func_config.precise_mode)
+            {
+                WARNING("The input type (float) doesn't match the data type assigned to the tensor (%s).",
+                        EnumNameTOSAREFTYPE(getDtype()));
+                return -2;
+            }
+            // continue with setting float vals in the tensor
         case TOSA_REF_TYPE_FP16:
         case TOSA_REF_TYPE_FP32:
             if (vals.size() != elements)
@@ -622,6 +630,14 @@
 
     switch (getDtype())
     {
+        case TOSA_REF_TYPE_FP64:
+            if (!g_func_config.precise_mode)
+            {
+                WARNING("The input type (float) doesn't match the data type assigned to the tensor (%s).",
+                        EnumNameTOSAREFTYPE(getDtype()));
+                return -2;
+            }
+            // continue with setting float vals in the tensor
         case TOSA_REF_TYPE_FP16:
             if (vals.size() != elements)
             {
@@ -953,7 +969,7 @@
 template <class T>
 int TosaReference::TensorTemplate<T>::setTensorValueDouble(const size_t buflen, const double* vals)
 {
-    FATAL_ERROR("TensorTemplate<T>::setTensorValueFloat should not be called.  "
+    FATAL_ERROR("TensorTemplate<T>::setTensorValueDouble should not be called.  "
                 "Implement template specialization version.");
     return 0;
 }
@@ -1254,6 +1270,150 @@
     return 0;
 }
 
+template <>
+int TosaReference::Tensor0<double>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+    (*tensor)(0) = vals[0];
+
+    return 0;
+}
+
+template <>
+int TosaReference::Tensor1<double>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+    uint32_t idx = 0;
+
+    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+    for (int i0 = 0; i0 < shape[0]; i0++)
+    {
+        (*tensor)(i0) = vals[idx++];
+    }
+
+    return 0;
+}
+
+template <>
+int TosaReference::Tensor2<double>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+    uint32_t idx = 0;
+
+    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+    for (int i0 = 0; i0 < shape[0]; i0++)
+    {
+        for (int i1 = 0; i1 < shape[1]; i1++)
+        {
+            (*tensor)(i0, i1) = vals[idx++];
+        }
+    }
+
+    return 0;
+}
+
+template <>
+int TosaReference::Tensor3<double>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+    uint32_t idx = 0;
+
+    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+    for (int i0 = 0; i0 < shape[0]; i0++)
+    {
+        for (int i1 = 0; i1 < shape[1]; i1++)
+        {
+            for (int i2 = 0; i2 < shape[2]; i2++)
+            {
+                (*tensor)(i0, i1, i2) = vals[idx++];
+            }
+        }
+    }
+
+    return 0;
+}
+
+template <>
+int TosaReference::Tensor4<double>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+    uint32_t idx = 0;
+
+    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+    for (int i0 = 0; i0 < shape[0]; i0++)
+    {
+        for (int i1 = 0; i1 < shape[1]; i1++)
+        {
+            for (int i2 = 0; i2 < shape[2]; i2++)
+            {
+                for (int i3 = 0; i3 < shape[3]; i3++)
+                {
+                    (*tensor)(i0, i1, i2, i3) = vals[idx++];
+                }
+            }
+        }
+    }
+
+    return 0;
+}
+
+template <>
+int TosaReference::Tensor5<double>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+    uint32_t idx = 0;
+
+    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+    for (int i0 = 0; i0 < shape[0]; i0++)
+    {
+        for (int i1 = 0; i1 < shape[1]; i1++)
+        {
+            for (int i2 = 0; i2 < shape[2]; i2++)
+            {
+                for (int i3 = 0; i3 < shape[3]; i3++)
+                {
+                    for (int i4 = 0; i4 < shape[4]; i4++)
+                    {
+                        (*tensor)(i0, i1, i2, i3, i4) = vals[idx++];
+                    }
+                }
+            }
+        }
+    }
+
+    return 0;
+}
+
+template <>
+int TosaReference::Tensor6<double>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+    uint32_t idx = 0;
+
+    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+    for (int i0 = 0; i0 < shape[0]; i0++)
+    {
+        for (int i1 = 0; i1 < shape[1]; i1++)
+        {
+            for (int i2 = 0; i2 < shape[2]; i2++)
+            {
+                for (int i3 = 0; i3 < shape[3]; i3++)
+                {
+                    for (int i4 = 0; i4 < shape[4]; i4++)
+                    {
+                        for (int i5 = 0; i5 < shape[5]; i5++)
+                        {
+                            (*tensor)(i0, i1, i2, i3, i4, i5) = vals[idx++];
+                        }
+                    }
+                }
+            }
+        }
+    }
+    return 0;
+}
+
 template <class T>
 int TosaReference::TensorTemplate<T>::setTensorValueInt16(const size_t bufLen, const int16_t* vals)
 {