IVGCVSW-4511 Add BFloat16 to RefLayerSupport and unit tests
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ifaae4d5aac468ba927b2c6a4bf31b8c8522aeb2e
diff --git a/src/backends/reference/workloads/Pad.cpp b/src/backends/reference/workloads/Pad.cpp
index 9fedb44..ffdd469 100644
--- a/src/backends/reference/workloads/Pad.cpp
+++ b/src/backends/reference/workloads/Pad.cpp
@@ -152,6 +152,13 @@
}
}
+template void Pad<BFloat16>(const TensorInfo& inputInfo,
+ const TensorInfo& outputInfo,
+ std::vector<std::pair<unsigned int, unsigned int>> m_PadList,
+ const BFloat16* inputData,
+ BFloat16* outData,
+ const float padValue);
+
template void Pad<float>(const TensorInfo& inputInfo,
const TensorInfo& outputInfo,
std::vector<std::pair<unsigned int, unsigned int>> m_PadList,
diff --git a/src/backends/reference/workloads/RefPadWorkload.cpp b/src/backends/reference/workloads/RefPadWorkload.cpp
index 356f6b1..777682d 100644
--- a/src/backends/reference/workloads/RefPadWorkload.cpp
+++ b/src/backends/reference/workloads/RefPadWorkload.cpp
@@ -33,6 +33,7 @@
Pad(inputInfo, outputInfo, m_Data.m_Parameters.m_PadList, inputData, outputData, m_Data.m_Parameters.m_PadValue);
}
+template class RefPadWorkload<DataType::BFloat16>;
template class RefPadWorkload<DataType::Float32>;
template class RefPadWorkload<DataType::Float16>;
template class RefPadWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefPadWorkload.hpp b/src/backends/reference/workloads/RefPadWorkload.hpp
index 28fb553..5134ac8 100644
--- a/src/backends/reference/workloads/RefPadWorkload.hpp
+++ b/src/backends/reference/workloads/RefPadWorkload.hpp
@@ -30,6 +30,7 @@
void Execute() const override;
};
+using RefPadBFloat16Workload = RefPadWorkload<DataType::BFloat16>;
using RefPadFloat32Workload = RefPadWorkload<DataType::Float32>;
using RefPadFloat16Workload = RefPadWorkload<DataType::Float16>;
using RefPadQAsymm8Workload = RefPadWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefPermuteWorkload.cpp b/src/backends/reference/workloads/RefPermuteWorkload.cpp
index d0e1431..5751ed8 100644
--- a/src/backends/reference/workloads/RefPermuteWorkload.cpp
+++ b/src/backends/reference/workloads/RefPermuteWorkload.cpp
@@ -28,6 +28,7 @@
src->Map(), dst->Map(), sizeof(T));
}
+template class RefPermuteWorkload<DataType::BFloat16>;
template class RefPermuteWorkload<DataType::Float16>;
template class RefPermuteWorkload<DataType::Float32>;
template class RefPermuteWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefPermuteWorkload.hpp b/src/backends/reference/workloads/RefPermuteWorkload.hpp
index 00a3385..a8d308e 100644
--- a/src/backends/reference/workloads/RefPermuteWorkload.hpp
+++ b/src/backends/reference/workloads/RefPermuteWorkload.hpp
@@ -27,6 +27,7 @@
void Execute() const override;
};
+using RefPermuteBFloat16Workload = RefPermuteWorkload<DataType::BFloat16>;
using RefPermuteFloat16Workload = RefPermuteWorkload<DataType::Float16>;
using RefPermuteFloat32Workload = RefPermuteWorkload<DataType::Float32>;
using RefPermuteQAsymm8Workload = RefPermuteWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefTransposeWorkload.cpp b/src/backends/reference/workloads/RefTransposeWorkload.cpp
index 6bdfb21..242668b 100644
--- a/src/backends/reference/workloads/RefTransposeWorkload.cpp
+++ b/src/backends/reference/workloads/RefTransposeWorkload.cpp
@@ -27,6 +27,7 @@
armnnUtils::Transpose(GetTensorInfo(src).GetShape(), mappings, src->Map(), dst->Map(), sizeof(T));
}
+template class RefTransposeWorkload<DataType::BFloat16>;
template class RefTransposeWorkload<DataType::Float16>;
template class RefTransposeWorkload<DataType::Float32>;
template class RefTransposeWorkload<DataType::QAsymmU8>;
diff --git a/src/backends/reference/workloads/RefTransposeWorkload.hpp b/src/backends/reference/workloads/RefTransposeWorkload.hpp
index 4b1c3d3..dcfe618 100644
--- a/src/backends/reference/workloads/RefTransposeWorkload.hpp
+++ b/src/backends/reference/workloads/RefTransposeWorkload.hpp
@@ -27,6 +27,7 @@
void Execute() const override;
};
+using RefTransposeBFloat16Workload = RefTransposeWorkload<DataType::BFloat16>;
using RefTransposeFloat16Workload = RefTransposeWorkload<DataType::Float16>;
using RefTransposeFloat32Workload = RefTransposeWorkload<DataType::Float32>;
using RefTransposeQAsymm8Workload = RefTransposeWorkload<DataType::QAsymmU8>;