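Reference model changes for FP16 support

In tensor_ops.h, the tensor op templates (AvgPool2d, Conv2d, Conv3d,
DepthwiseConv2d, FullyConnected, MatMul, TransposeConv2d) now take the
accumulator type as an explicit AccDtype template parameter instead of
deriving it with GetAccDType. The accumulator Eigen type comes from
GetAccEigenType<AccDtype>, while bias/output tensors use
GetEigenType<AccDtype>.
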
Change-Id: I72f21fcfa153046274969d327313e3349981dbe6
Signed-off-by: James Ward <james.ward@arm.com>
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index 24eadeb..fd6dd25 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
 //
 //    Licensed under the Apache License, Version 2.0 (the "License");
 //    you may not use this file except in compliance with the License.
@@ -45,7 +45,7 @@
     TosaReference::TensorTemplate<TOut>* output;
 };
 
-template <DType Dtype>
+template <DType Dtype, DType AccDtype>
 class OpAvgPool2d : public GraphNode
 {
 public:
@@ -55,9 +55,8 @@
     virtual int checkTensorAttributes();
     virtual int eval();
 
-    static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
     using InEigenType               = typename GetEigenType<Dtype>::type;
-    using AccEigenType              = typename GetEigenType<AccDtype>::type;
+    using AccEigenType              = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
     using OutEigenType              = typename GetEigenType<Dtype>::type;
     using TIn                       = Eigen::Tensor<InEigenType, 4>;
     using TOut                      = Eigen::Tensor<OutEigenType, 4>;
@@ -75,7 +74,7 @@
     ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right);
 };
 
-template <DType InDtype, DType WeightDtype>
+template <DType InDtype, DType WeightDtype, DType AccDtype>
 class OpConv2d : public GraphNode
 {
 public:
@@ -85,15 +84,14 @@
     virtual int checkTensorAttributes() final;
     virtual int eval() final;
 
-    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
-
     using InEigenType     = typename GetEigenType<InDtype>::type;
     using WeightEigenType = typename GetEigenType<WeightDtype>::type;
-    using AccEigenType    = typename GetEigenType<AccDtype>::type;
+    using AccEigenType    = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
+    using OutEigenType    = typename GetEigenType<AccDtype>::type;
     using TIn             = Eigen::Tensor<InEigenType, 4>;
     using TWeight         = Eigen::Tensor<WeightEigenType, 4>;
-    using TBias           = Eigen::Tensor<AccEigenType, 1>;
-    using TAcc            = Eigen::Tensor<AccEigenType, 4>;
+    using TBias           = Eigen::Tensor<OutEigenType, 1>;
+    using TOut            = Eigen::Tensor<OutEigenType, 4>;
 
     static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
     static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
@@ -102,11 +100,11 @@
     TosaReference::TensorTemplate<TIn>* input;
     TosaReference::TensorTemplate<TWeight>* weight;
     TosaReference::TensorTemplate<TBias>* bias;
-    TosaReference::TensorTemplate<TAcc>* output;
+    TosaReference::TensorTemplate<TOut>* output;
     tosa::TosaConvAttribute* attribute;
 };
 
-template <DType InDtype, DType WeightDtype>
+template <DType InDtype, DType WeightDtype, DType AccDtype>
 class OpConv3d : public GraphNode
 {
 public:
@@ -116,15 +114,14 @@
     virtual int checkTensorAttributes() final;
     virtual int eval() final;
 
-    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
-
     using InEigenType     = typename GetEigenType<InDtype>::type;
     using WeightEigenType = typename GetEigenType<WeightDtype>::type;
-    using AccEigenType    = typename GetEigenType<AccDtype>::type;
+    using AccEigenType    = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
+    using OutEigenType    = typename GetEigenType<AccDtype>::type;
     using TIn             = Eigen::Tensor<InEigenType, 5>;
     using TWeight         = Eigen::Tensor<WeightEigenType, 5>;
-    using TBias           = Eigen::Tensor<AccEigenType, 1>;
-    using TAcc            = Eigen::Tensor<AccEigenType, 5>;
+    using TBias           = Eigen::Tensor<OutEigenType, 1>;
+    using TOut            = Eigen::Tensor<OutEigenType, 5>;
 
     static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
     static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
@@ -133,11 +130,11 @@
     TosaReference::TensorTemplate<TIn>* input;
     TosaReference::TensorTemplate<TWeight>* weight;
     TosaReference::TensorTemplate<TBias>* bias;
-    TosaReference::TensorTemplate<TAcc>* output;
+    TosaReference::TensorTemplate<TOut>* output;
     tosa::TosaConvAttribute* attribute;
 };
 
-template <DType InDtype, DType WeightDtype>
+template <DType InDtype, DType WeightDtype, DType AccDtype>
 class OpDepthwiseConv2d : public GraphNode
 {
 public:
@@ -147,15 +144,14 @@
     virtual int checkTensorAttributes() final;
     virtual int eval() final;
 
-    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
-
     using InEigenType     = typename GetEigenType<InDtype>::type;
     using WeightEigenType = typename GetEigenType<WeightDtype>::type;
-    using AccEigenType    = typename GetEigenType<AccDtype>::type;
+    using AccEigenType    = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
+    using OutEigenType    = typename GetEigenType<AccDtype>::type;
     using TIn             = Eigen::Tensor<InEigenType, 4>;
     using TWeight         = Eigen::Tensor<WeightEigenType, 4>;
-    using TBias           = Eigen::Tensor<AccEigenType, 1>;
-    using TAcc            = Eigen::Tensor<AccEigenType, 4>;
+    using TBias           = Eigen::Tensor<OutEigenType, 1>;
+    using TOut            = Eigen::Tensor<OutEigenType, 4>;
 
     static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
     static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
@@ -164,11 +160,11 @@
     TosaReference::TensorTemplate<TIn>* input;
     TosaReference::TensorTemplate<TWeight>* weight;
     TosaReference::TensorTemplate<TBias>* bias;
-    TosaReference::TensorTemplate<TAcc>* output;
+    TosaReference::TensorTemplate<TOut>* output;
     tosa::TosaConvAttribute* attribute;
 };
 
-template <DType InDtype, DType WeightDtype>
+template <DType InDtype, DType WeightDtype, DType AccDtype>
 class OpFullyConnected : public GraphNode
 {
 public:
@@ -178,14 +174,14 @@
     virtual int checkTensorAttributes() final;
     virtual int eval() final;
 
-    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
     using InEigenType               = typename GetEigenType<InDtype>::type;
     using WeightEigenType           = typename GetEigenType<WeightDtype>::type;
-    using AccEigenType              = typename GetEigenType<AccDtype>::type;
+    using AccEigenType              = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
+    using OutEigenType              = typename GetEigenType<AccDtype>::type;
     using TIn                       = Eigen::Tensor<InEigenType, 2>;
     using TWeight                   = Eigen::Tensor<WeightEigenType, 2>;
-    using TBias                     = Eigen::Tensor<AccEigenType, 1>;
-    using TAcc                      = Eigen::Tensor<AccEigenType, 2>;
+    using TBias                     = Eigen::Tensor<OutEigenType, 1>;
+    using TOut                      = Eigen::Tensor<OutEigenType, 2>;
 
     static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
     static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
@@ -194,12 +190,12 @@
     TosaReference::TensorTemplate<TIn>* input;
     TosaReference::TensorTemplate<TWeight>* weight;
     TosaReference::TensorTemplate<TBias>* bias;
-    TosaReference::TensorTemplate<TAcc>* output;
+    TosaReference::TensorTemplate<TOut>* output;
 
     tosa::TosaFullyConnectedAttribute* attribute;
 };
 
-template <DType Dtype>
+template <DType Dtype, DType AccDtype>
 class OpMatMul : public GraphNode
 {
 public:
@@ -209,11 +205,11 @@
     virtual int checkTensorAttributes() final;
     virtual int eval() final;
 
-    static constexpr DType AccDtype  = GetAccDType<Dtype, Dtype>::value;
     using InEigenType                = typename GetEigenType<Dtype>::type;
-    using AccEigenType               = typename GetEigenType<AccDtype>::type;
+    using AccEigenType               = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
+    using OutEigenType               = typename GetEigenType<AccDtype>::type;
     using TIn                        = Eigen::Tensor<InEigenType, 3>;
-    using TAcc                       = Eigen::Tensor<AccEigenType, 3>;
+    using TOut                       = Eigen::Tensor<OutEigenType, 3>;
     using TInRank2                   = Eigen::Tensor<InEigenType, 2>;
     using TAccRank2                  = Eigen::Tensor<AccEigenType, 2>;
     static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
@@ -222,7 +218,7 @@
 protected:
     TosaReference::TensorTemplate<TIn>* a;
     TosaReference::TensorTemplate<TIn>* b;
-    TosaReference::TensorTemplate<TAcc>* output;
+    TosaReference::TensorTemplate<TOut>* output;
     int64_t N;
     int64_t H;
     int64_t W;
@@ -252,7 +248,7 @@
     tosa::TosaPoolAttribute* attribute;
 };
 
-template <DType InDtype, DType WeightDtype>
+template <DType InDtype, DType WeightDtype, DType AccDtype>
 class OpTransposeConv2d : public GraphNode
 {
 public:
@@ -262,15 +258,14 @@
     virtual int checkTensorAttributes() final;
     virtual int eval() final;
 
-    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
-
     using InEigenType     = typename GetEigenType<InDtype>::type;
     using WeightEigenType = typename GetEigenType<WeightDtype>::type;
-    using AccEigenType    = typename GetEigenType<AccDtype>::type;
+    using AccEigenType    = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
+    using OutEigenType    = typename GetEigenType<AccDtype>::type;
     using TIn             = Eigen::Tensor<InEigenType, 4>;
     using TWeight         = Eigen::Tensor<WeightEigenType, 4>;
-    using TBias           = Eigen::Tensor<AccEigenType, 1>;
-    using TAcc            = Eigen::Tensor<AccEigenType, 4>;
+    using TBias           = Eigen::Tensor<OutEigenType, 1>;
+    using TOut            = Eigen::Tensor<OutEigenType, 4>;
 
     static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
     static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
@@ -279,7 +274,7 @@
     TosaReference::TensorTemplate<TIn>* input;
     TosaReference::TensorTemplate<TWeight>* weight;
     TosaReference::TensorTemplate<TBias>* bias;
-    TosaReference::TensorTemplate<TAcc>* output;
+    TosaReference::TensorTemplate<TOut>* output;
     TosaTransposeConvAttribute* attribute;
 };