IVGCVSW-949 Simplify use of InitialiseArmComputeClTensorData

Change-Id: I556881e34f26e8152feaaba06d99828394872f58
diff --git a/src/backends/ClWorkloads/ClBatchNormalizationFloatWorkload.cpp b/src/backends/ClWorkloads/ClBatchNormalizationFloatWorkload.cpp
index 021734a..d05349b 100644
--- a/src/backends/ClWorkloads/ClBatchNormalizationFloatWorkload.cpp
+++ b/src/backends/ClWorkloads/ClBatchNormalizationFloatWorkload.cpp
@@ -68,10 +68,10 @@
                       m_Gamma.get(),
                       m_Data.m_Parameters.m_Eps);
 
-    InitializeArmComputeClTensorDataForFloatTypes(*m_Mean, m_Data.m_Mean);
-    InitializeArmComputeClTensorDataForFloatTypes(*m_Variance, m_Data.m_Variance);
-    InitializeArmComputeClTensorDataForFloatTypes(*m_Beta, m_Data.m_Beta);
-    InitializeArmComputeClTensorDataForFloatTypes(*m_Gamma, m_Data.m_Gamma);
+    InitializeArmComputeClTensorData(*m_Mean, m_Data.m_Mean);
+    InitializeArmComputeClTensorData(*m_Variance, m_Data.m_Variance);
+    InitializeArmComputeClTensorData(*m_Beta, m_Data.m_Beta);
+    InitializeArmComputeClTensorData(*m_Gamma, m_Data.m_Gamma);
 
     // Force Compute Library to perform the necessary copying and reshaping, after which
     // delete all the input tensors that will no longer be needed
diff --git a/src/backends/ClWorkloads/ClConvolution2dFloatWorkload.cpp b/src/backends/ClWorkloads/ClConvolution2dFloatWorkload.cpp
index 029f41d..f0b9a46 100644
--- a/src/backends/ClWorkloads/ClConvolution2dFloatWorkload.cpp
+++ b/src/backends/ClWorkloads/ClConvolution2dFloatWorkload.cpp
@@ -52,11 +52,11 @@
                                  &output,
                                  padStrideInfo);
 
-    InitializeArmComputeClTensorDataForFloatTypes(*m_KernelTensor, m_Data.m_Weight);
+    InitializeArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight);
 
     if (m_BiasTensor)
     {
-        InitializeArmComputeClTensorDataForFloatTypes(*m_BiasTensor, m_Data.m_Bias);
+        InitializeArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias);
     }
 
     // Force Compute Library to perform the necessary copying and reshaping, after which
diff --git a/src/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp b/src/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp
index e6783b6..c9f5eaa 100644
--- a/src/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp
+++ b/src/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp
@@ -51,11 +51,11 @@
                                  &output,
                                  padStrideInfo);
 
-    InitialiseArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight->GetConstTensor<uint8_t>());
+    InitializeArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight);
 
     if (m_BiasTensor)
     {
-        InitialiseArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias->GetConstTensor<int32_t>());
+        InitializeArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias);
     }
 
     // Force Compute Library to perform the necessary copying and reshaping, after which
diff --git a/src/backends/ClWorkloads/ClDepthwiseConvolutionFloatWorkload.cpp b/src/backends/ClWorkloads/ClDepthwiseConvolutionFloatWorkload.cpp
index 635ae1f..bc3b165 100644
--- a/src/backends/ClWorkloads/ClDepthwiseConvolutionFloatWorkload.cpp
+++ b/src/backends/ClWorkloads/ClDepthwiseConvolutionFloatWorkload.cpp
@@ -17,11 +17,11 @@
     const WorkloadInfo& info)
     : ClDepthwiseConvolutionBaseWorkload(descriptor, info)
 {
-    InitializeArmComputeClTensorDataForFloatTypes(*m_KernelTensor, m_Data.m_Weight);
+    InitializeArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight);
 
     if (m_BiasTensor)
     {
-        InitializeArmComputeClTensorDataForFloatTypes(*m_BiasTensor, m_Data.m_Bias);
+        InitializeArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias);
     }
 
     m_DepthwiseConvolutionLayer->prepare();
diff --git a/src/backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.cpp b/src/backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.cpp
index af5836e..4ea5590 100644
--- a/src/backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.cpp
+++ b/src/backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.cpp
@@ -17,11 +17,11 @@
     const WorkloadInfo& info)
     : ClDepthwiseConvolutionBaseWorkload(descriptor, info)
 {
-    InitialiseArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight->template GetConstTensor<uint8_t>());
+    InitializeArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight);
 
     if (m_BiasTensor)
     {
-        InitialiseArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias->template GetConstTensor<int32_t>());
+        InitializeArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias);
     }
 
     m_DepthwiseConvolutionLayer->prepare();
diff --git a/src/backends/ClWorkloads/ClFullyConnectedWorkload.cpp b/src/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
index 8d2fd0e..4686d1c 100644
--- a/src/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
+++ b/src/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
@@ -68,26 +68,11 @@
     fc_info.transpose_weights = m_Data.m_Parameters.m_TransposeWeightMatrix;
     m_FullyConnectedLayer.configure(&input, m_WeightsTensor.get(), m_BiasesTensor.get(), &output, fc_info);
 
-    // Allocate
-    if (m_Data.m_Weight->GetTensorInfo().GetDataType() == DataType::QuantisedAsymm8)
-    {
-        InitialiseArmComputeClTensorData(*m_WeightsTensor, m_Data.m_Weight->GetConstTensor<uint8_t>());
-    }
-    else
-    {
-        InitializeArmComputeClTensorDataForFloatTypes(*m_WeightsTensor, m_Data.m_Weight);
-    }
+    InitializeArmComputeClTensorData(*m_WeightsTensor, m_Data.m_Weight);
 
     if (m_BiasesTensor)
     {
-        if (m_Data.m_Bias->GetTensorInfo().GetDataType() == DataType::Signed32)
-        {
-            InitialiseArmComputeClTensorData(*m_BiasesTensor, m_Data.m_Bias->GetConstTensor<int32_t>());
-        }
-        else
-        {
-            InitializeArmComputeClTensorDataForFloatTypes(*m_BiasesTensor, m_Data.m_Bias);
-        }
+        InitializeArmComputeClTensorData(*m_BiasesTensor, m_Data.m_Bias);
     }
 
     // Force Compute Library to perform the necessary copying and reshaping, after which
diff --git a/src/backends/ClWorkloads/ClLstmFloatWorkload.cpp b/src/backends/ClWorkloads/ClLstmFloatWorkload.cpp
index 09a34c2..8e2c875 100644
--- a/src/backends/ClWorkloads/ClLstmFloatWorkload.cpp
+++ b/src/backends/ClWorkloads/ClLstmFloatWorkload.cpp
@@ -172,57 +172,40 @@
 
     armcomputetensorutils::InitialiseArmComputeTensorEmpty(*m_ScratchBuffer);
 
-    InitialiseArmComputeClTensorData(*m_InputToForgetWeightsTensor,
-                                     m_Data.m_InputToForgetWeights->GetConstTensor<float>());
-    InitialiseArmComputeClTensorData(*m_InputToCellWeightsTensor,
-                                     m_Data.m_InputToCellWeights->GetConstTensor<float>());
-    InitialiseArmComputeClTensorData(*m_InputToOutputWeightsTensor,
-                                     m_Data.m_InputToOutputWeights->GetConstTensor<float>());
-    InitialiseArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor,
-                                     m_Data.m_RecurrentToForgetWeights->GetConstTensor<float>());
-    InitialiseArmComputeClTensorData(*m_RecurrentToCellWeightsTensor,
-                                     m_Data.m_RecurrentToCellWeights->GetConstTensor<float>());
-    InitialiseArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor,
-                                     m_Data.m_RecurrentToOutputWeights->GetConstTensor<float>());
-    InitialiseArmComputeClTensorData(*m_ForgetGateBiasTensor,
-                                     m_Data.m_ForgetGateBias->GetConstTensor<float>());
-    InitialiseArmComputeClTensorData(*m_CellBiasTensor,
-                                     m_Data.m_CellBias->GetConstTensor<float>());
-    InitialiseArmComputeClTensorData(*m_OutputGateBiasTensor,
-                                     m_Data.m_OutputGateBias->GetConstTensor<float>());
+    InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
+    InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights);
+    InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);
+    InitializeArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights);
+    InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights);
+    InitializeArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights);
+    InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
+    InitializeArmComputeClTensorData(*m_CellBiasTensor, m_Data.m_CellBias);
+    InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);
 
     if (!m_Data.m_Parameters.m_CifgEnabled)
     {
-        InitialiseArmComputeClTensorData(*m_InputToInputWeightsTensor,
-                                         m_Data.m_InputToInputWeights->GetConstTensor<float>());
-        InitialiseArmComputeClTensorData(*m_RecurrentToInputWeightsTensor,
-                                         m_Data.m_RecurrentToInputWeights->GetConstTensor<float>());
+        InitializeArmComputeClTensorData(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights);
+        InitializeArmComputeClTensorData(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights);
         if (m_Data.m_CellToInputWeights != nullptr)
         {
-            InitialiseArmComputeClTensorData(*m_CellToInputWeightsTensor,
-                                             m_Data.m_CellToInputWeights->GetConstTensor<float>());
+            InitializeArmComputeClTensorData(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights);
         }
-        InitialiseArmComputeClTensorData(*m_InputGateBiasTensor,
-                                         m_Data.m_InputGateBias->GetConstTensor<float>());
+        InitializeArmComputeClTensorData(*m_InputGateBiasTensor, m_Data.m_InputGateBias);
     }
 
     if (m_Data.m_Parameters.m_ProjectionEnabled)
     {
-        InitialiseArmComputeClTensorData(*m_ProjectionWeightsTensor,
-                                         m_Data.m_ProjectionWeights->GetConstTensor<float>());
+        InitializeArmComputeClTensorData(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights);
         if (m_Data.m_ProjectionBias != nullptr)
         {
-            InitialiseArmComputeClTensorData(*m_ProjectionBiasTensor,
-                                             m_Data.m_ProjectionBias->GetConstTensor<float>());
+            InitializeArmComputeClTensorData(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias);
         }
     }
 
     if (m_Data.m_Parameters.m_PeepholeEnabled)
     {
-        InitialiseArmComputeClTensorData(*m_CellToForgetWeightsTensor,
-                                         m_Data.m_CellToForgetWeights->GetConstTensor<float>());
-        InitialiseArmComputeClTensorData(*m_CellToOutputWeightsTensor,
-                                         m_Data.m_CellToOutputWeights->GetConstTensor<float>());
+        InitializeArmComputeClTensorData(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights);
+        InitializeArmComputeClTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights);
     }
 
     // Force Compute Library to perform the necessary copying and reshaping, after which
diff --git a/src/backends/ClWorkloads/ClWorkloadUtils.hpp b/src/backends/ClWorkloads/ClWorkloadUtils.hpp
index 6f1b155..a10237c 100644
--- a/src/backends/ClWorkloads/ClWorkloadUtils.hpp
+++ b/src/backends/ClWorkloads/ClWorkloadUtils.hpp
@@ -42,8 +42,8 @@
     CopyArmComputeClTensorData<T>(data, clTensor);
 }
 
-inline void InitializeArmComputeClTensorDataForFloatTypes(arm_compute::CLTensor& clTensor,
-                                                          const ConstCpuTensorHandle *handle)
+inline void InitializeArmComputeClTensorData(arm_compute::CLTensor& clTensor,
+                                             const ConstCpuTensorHandle* handle)
 {
     BOOST_ASSERT(handle);
     switch(handle->GetTensorInfo().GetDataType())
@@ -54,8 +54,14 @@
         case DataType::Float32:
             InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<float>());
             break;
+        case DataType::QuantisedAsymm8:
+            InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<uint8_t>());
+            break;
+        case DataType::Signed32:
+            InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<int32_t>());
+            break;
         default:
-            BOOST_ASSERT_MSG(false, "Unexpected floating point type.");
+            BOOST_ASSERT_MSG(false, "Unexpected tensor type.");
     }
 };