IVGCVSW-949 Simplify use of InitialiseArmComputeClTensorData
Change-Id: I556881e34f26e8152feaaba06d99828394872f58
diff --git a/src/backends/ClWorkloads/ClBatchNormalizationFloatWorkload.cpp b/src/backends/ClWorkloads/ClBatchNormalizationFloatWorkload.cpp
index 021734a..d05349b 100644
--- a/src/backends/ClWorkloads/ClBatchNormalizationFloatWorkload.cpp
+++ b/src/backends/ClWorkloads/ClBatchNormalizationFloatWorkload.cpp
@@ -68,10 +68,10 @@
m_Gamma.get(),
m_Data.m_Parameters.m_Eps);
- InitializeArmComputeClTensorDataForFloatTypes(*m_Mean, m_Data.m_Mean);
- InitializeArmComputeClTensorDataForFloatTypes(*m_Variance, m_Data.m_Variance);
- InitializeArmComputeClTensorDataForFloatTypes(*m_Beta, m_Data.m_Beta);
- InitializeArmComputeClTensorDataForFloatTypes(*m_Gamma, m_Data.m_Gamma);
+ InitializeArmComputeClTensorData(*m_Mean, m_Data.m_Mean);
+ InitializeArmComputeClTensorData(*m_Variance, m_Data.m_Variance);
+ InitializeArmComputeClTensorData(*m_Beta, m_Data.m_Beta);
+ InitializeArmComputeClTensorData(*m_Gamma, m_Data.m_Gamma);
// Force Compute Library to perform the necessary copying and reshaping, after which
// delete all the input tensors that will no longer be needed
diff --git a/src/backends/ClWorkloads/ClConvolution2dFloatWorkload.cpp b/src/backends/ClWorkloads/ClConvolution2dFloatWorkload.cpp
index 029f41d..f0b9a46 100644
--- a/src/backends/ClWorkloads/ClConvolution2dFloatWorkload.cpp
+++ b/src/backends/ClWorkloads/ClConvolution2dFloatWorkload.cpp
@@ -52,11 +52,11 @@
&output,
padStrideInfo);
- InitializeArmComputeClTensorDataForFloatTypes(*m_KernelTensor, m_Data.m_Weight);
+ InitializeArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight);
if (m_BiasTensor)
{
- InitializeArmComputeClTensorDataForFloatTypes(*m_BiasTensor, m_Data.m_Bias);
+ InitializeArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias);
}
// Force Compute Library to perform the necessary copying and reshaping, after which
diff --git a/src/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp b/src/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp
index e6783b6..c9f5eaa 100644
--- a/src/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp
+++ b/src/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp
@@ -51,11 +51,11 @@
&output,
padStrideInfo);
- InitialiseArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight->GetConstTensor<uint8_t>());
+ InitializeArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight);
if (m_BiasTensor)
{
- InitialiseArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias->GetConstTensor<int32_t>());
+ InitializeArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias);
}
// Force Compute Library to perform the necessary copying and reshaping, after which
diff --git a/src/backends/ClWorkloads/ClDepthwiseConvolutionFloatWorkload.cpp b/src/backends/ClWorkloads/ClDepthwiseConvolutionFloatWorkload.cpp
index 635ae1f..bc3b165 100644
--- a/src/backends/ClWorkloads/ClDepthwiseConvolutionFloatWorkload.cpp
+++ b/src/backends/ClWorkloads/ClDepthwiseConvolutionFloatWorkload.cpp
@@ -17,11 +17,11 @@
const WorkloadInfo& info)
: ClDepthwiseConvolutionBaseWorkload(descriptor, info)
{
- InitializeArmComputeClTensorDataForFloatTypes(*m_KernelTensor, m_Data.m_Weight);
+ InitializeArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight);
if (m_BiasTensor)
{
- InitializeArmComputeClTensorDataForFloatTypes(*m_BiasTensor, m_Data.m_Bias);
+ InitializeArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias);
}
m_DepthwiseConvolutionLayer->prepare();
diff --git a/src/backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.cpp b/src/backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.cpp
index af5836e..4ea5590 100644
--- a/src/backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.cpp
+++ b/src/backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.cpp
@@ -17,11 +17,11 @@
const WorkloadInfo& info)
: ClDepthwiseConvolutionBaseWorkload(descriptor, info)
{
- InitialiseArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight->template GetConstTensor<uint8_t>());
+ InitializeArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight);
if (m_BiasTensor)
{
- InitialiseArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias->template GetConstTensor<int32_t>());
+ InitializeArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias);
}
m_DepthwiseConvolutionLayer->prepare();
diff --git a/src/backends/ClWorkloads/ClFullyConnectedWorkload.cpp b/src/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
index 8d2fd0e..4686d1c 100644
--- a/src/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
+++ b/src/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
@@ -68,26 +68,11 @@
fc_info.transpose_weights = m_Data.m_Parameters.m_TransposeWeightMatrix;
m_FullyConnectedLayer.configure(&input, m_WeightsTensor.get(), m_BiasesTensor.get(), &output, fc_info);
- // Allocate
- if (m_Data.m_Weight->GetTensorInfo().GetDataType() == DataType::QuantisedAsymm8)
- {
- InitialiseArmComputeClTensorData(*m_WeightsTensor, m_Data.m_Weight->GetConstTensor<uint8_t>());
- }
- else
- {
- InitializeArmComputeClTensorDataForFloatTypes(*m_WeightsTensor, m_Data.m_Weight);
- }
+ InitializeArmComputeClTensorData(*m_WeightsTensor, m_Data.m_Weight);
if (m_BiasesTensor)
{
- if (m_Data.m_Bias->GetTensorInfo().GetDataType() == DataType::Signed32)
- {
- InitialiseArmComputeClTensorData(*m_BiasesTensor, m_Data.m_Bias->GetConstTensor<int32_t>());
- }
- else
- {
- InitializeArmComputeClTensorDataForFloatTypes(*m_BiasesTensor, m_Data.m_Bias);
- }
+ InitializeArmComputeClTensorData(*m_BiasesTensor, m_Data.m_Bias);
}
// Force Compute Library to perform the necessary copying and reshaping, after which
diff --git a/src/backends/ClWorkloads/ClLstmFloatWorkload.cpp b/src/backends/ClWorkloads/ClLstmFloatWorkload.cpp
index 09a34c2..8e2c875 100644
--- a/src/backends/ClWorkloads/ClLstmFloatWorkload.cpp
+++ b/src/backends/ClWorkloads/ClLstmFloatWorkload.cpp
@@ -172,57 +172,40 @@
armcomputetensorutils::InitialiseArmComputeTensorEmpty(*m_ScratchBuffer);
- InitialiseArmComputeClTensorData(*m_InputToForgetWeightsTensor,
- m_Data.m_InputToForgetWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_InputToCellWeightsTensor,
- m_Data.m_InputToCellWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_InputToOutputWeightsTensor,
- m_Data.m_InputToOutputWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor,
- m_Data.m_RecurrentToForgetWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_RecurrentToCellWeightsTensor,
- m_Data.m_RecurrentToCellWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor,
- m_Data.m_RecurrentToOutputWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_ForgetGateBiasTensor,
- m_Data.m_ForgetGateBias->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_CellBiasTensor,
- m_Data.m_CellBias->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_OutputGateBiasTensor,
- m_Data.m_OutputGateBias->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
+ InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights);
+ InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);
+ InitializeArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights);
+ InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights);
+ InitializeArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights);
+ InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
+ InitializeArmComputeClTensorData(*m_CellBiasTensor, m_Data.m_CellBias);
+ InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);
if (!m_Data.m_Parameters.m_CifgEnabled)
{
- InitialiseArmComputeClTensorData(*m_InputToInputWeightsTensor,
- m_Data.m_InputToInputWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_RecurrentToInputWeightsTensor,
- m_Data.m_RecurrentToInputWeights->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights);
+ InitializeArmComputeClTensorData(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights);
if (m_Data.m_CellToInputWeights != nullptr)
{
- InitialiseArmComputeClTensorData(*m_CellToInputWeightsTensor,
- m_Data.m_CellToInputWeights->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights);
}
- InitialiseArmComputeClTensorData(*m_InputGateBiasTensor,
- m_Data.m_InputGateBias->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_InputGateBiasTensor, m_Data.m_InputGateBias);
}
if (m_Data.m_Parameters.m_ProjectionEnabled)
{
- InitialiseArmComputeClTensorData(*m_ProjectionWeightsTensor,
- m_Data.m_ProjectionWeights->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights);
if (m_Data.m_ProjectionBias != nullptr)
{
- InitialiseArmComputeClTensorData(*m_ProjectionBiasTensor,
- m_Data.m_ProjectionBias->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias);
}
}
if (m_Data.m_Parameters.m_PeepholeEnabled)
{
- InitialiseArmComputeClTensorData(*m_CellToForgetWeightsTensor,
- m_Data.m_CellToForgetWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_CellToOutputWeightsTensor,
- m_Data.m_CellToOutputWeights->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights);
+ InitializeArmComputeClTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights);
}
// Force Compute Library to perform the necessary copying and reshaping, after which
diff --git a/src/backends/ClWorkloads/ClWorkloadUtils.hpp b/src/backends/ClWorkloads/ClWorkloadUtils.hpp
index 6f1b155..a10237c 100644
--- a/src/backends/ClWorkloads/ClWorkloadUtils.hpp
+++ b/src/backends/ClWorkloads/ClWorkloadUtils.hpp
@@ -42,8 +42,8 @@
CopyArmComputeClTensorData<T>(data, clTensor);
}
-inline void InitializeArmComputeClTensorDataForFloatTypes(arm_compute::CLTensor& clTensor,
- const ConstCpuTensorHandle *handle)
+inline void InitializeArmComputeClTensorData(arm_compute::CLTensor& clTensor,
+ const ConstCpuTensorHandle* handle)
{
BOOST_ASSERT(handle);
switch(handle->GetTensorInfo().GetDataType())
@@ -54,8 +54,14 @@
case DataType::Float32:
InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<float>());
break;
+ case DataType::QuantisedAsymm8:
+ InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<uint8_t>());
+ break;
+ case DataType::Signed32:
+ InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<int32_t>());
+ break;
default:
- BOOST_ASSERT_MSG(false, "Unexpected floating point type.");
+ BOOST_ASSERT_MSG(false, "Unexpected tensor type.");
}
};