Reference model changes for fp16 support

Change-Id: I72f21fcfa153046274969d327313e3349981dbe6
Signed-off-by: James Ward <james.ward@arm.com>
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 52de2e4..50e710a 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
 //
 //    Licensed under the Apache License, Version 2.0 (the "License");
 //    you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
 #include "quant_util.h"
 #include "template_types.h"
 #include <cmath>
+#include "half.hpp"
 
 using namespace TosaReference;
 using namespace Eigen;
@@ -287,6 +288,51 @@
 }
 
+// Integer -> FP16 cast. FP16 tensor data is held in float, so the integer is
+// converted to half precision and the resulting FP16 value is returned widened
+// back to float.
 template <DType InDtype>
+CastHelper<InDtype, DType_FP16>::CastHelper()
+{
+    fcn = [](InEigenType in) -> float {
+        // Convert to half precision using half_float's default rounding style.
+        half_float::half out = half_float::half_cast<half_float::half, InEigenType>(in);
+        // Widen back to float, the storage type used for FP16 tensors.
+        return half_float::half_cast<float, half_float::half>(out);
+    };
+}
+
+// FP16 -> integer cast. The input is an FP16 value held in float (this assumes
+// InEigenType = float); it is rounded to an integral value and saturated to
+// [OutMin, OutMax].
+template <DType OutDtype>
+CastHelper<DType_FP16, OutDtype>::CastHelper()
+{
+    fcn = [](float in) -> OutEigenType {
+        // Round in half precision, then widen back to float; every FP16 value is
+        // exactly representable in float, so the widening is lossless.
+        // NOTE(review): std::round rounds ties away from zero - confirm this is
+        // the rounding mode CAST requires.
+        half_float::half h  = half_float::half_cast<half_float::half, float>(in);
+        h                   = std::round(h);
+        const float rounded = half_float::half_cast<float, half_float::half>(h);
+        // Saturate in the float domain *before* converting: converting a float that
+        // is out of range for the integer type (e.g. +/-inf) is undefined behaviour,
+        // so clamping only after the conversion cannot prevent it.
+        // NOTE(review): NaN compares false against both bounds and still reaches the
+        // final cast; decide on a defined result for NaN.
+        if (rounded >= static_cast<float>(OutMax))
+        {
+            return OutMax;
+        }
+        if (rounded <= static_cast<float>(OutMin))
+        {
+            return OutMin;
+        }
+        return static_cast<OutEigenType>(rounded);
+    };
+}
+
+template <DType InDtype>
 CastHelper<InDtype, DType_FLOAT>::CastHelper()
 {
     fcn = [](InEigenType in) -> float {
@@ -313,15 +338,21 @@
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16);    // fp16 support: INT8 -> FP16
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16);    // fp16 support: INT16 -> FP16
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16);    // fp16 support: INT32 -> FP16
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);    // fp16 support: FP16 -> integer casts
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32);