COMPMID-421: Added FP16 support to the NEON Direct Convolution function.
Change-Id: I3a1aa2ce985ecf95fc5f441a6e6d43b4935306ee
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79965
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerBiasAccumulateKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerBiasAccumulateKernel.cpp
index effc50e..fb16c8d 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerBiasAccumulateKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerBiasAccumulateKernel.cpp
@@ -100,6 +100,25 @@
return vqaddq_qs16(x, y);
}
+#ifdef ARM_COMPUTE_ENABLE_FP16
+inline float16x8_t internal_vld1q(const float16_t *in) // FP16 overload: returns the float16x8_t that vld1q_f16 loads from 'in'.
+{
+ return vld1q_f16(in);
+}
+inline void internal_vst1q(float16_t *p, const float16x8_t &v) // FP16 overload: stores the vector 'v' to 'p' via vst1q_f16.
+{
+ vst1q_f16(p, v);
+}
+inline float16x8_t internal_vdupq_n(float16_t v) // FP16 overload: returns the float16x8_t that vdupq_n_f16 builds from the scalar 'v'.
+{
+ return vdupq_n_f16(v);
+}
+inline float16x8_t internal_vqaddq(const float16x8_t &x, const float16x8_t &y) // FP16 overload: returns vaddq_f16(x, y).
+{
+ return vaddq_f16(x, y); // Keeps the 'vqadd' name of the sibling overloads (e.g. the one calling vqaddq_qs16) but calls vaddq_f16, not a vqadd* intrinsic.
+}
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
template <typename T1, typename T2, bool in_place>
void accumulate_bias(ITensor *input, const ITensor *bias, const Window window, ITensor *output)
{
@@ -143,8 +162,8 @@
void NEDirectConvolutionLayerBiasAccumulateKernel::configure(ITensor *input, const ITensor *bias, ITensor *output)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::QS8, DataType::QS16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON(input->info()->fixed_point_position() != bias->info()->fixed_point_position());
if(output != nullptr)
{
@@ -183,6 +202,12 @@
{
_func = (output == nullptr) ? &accumulate_bias<float, float, true> : &accumulate_bias<float, float, false>;
}
+#ifdef ARM_COMPUTE_ENABLE_FP16
+ else if(input->info()->data_type() == DataType::F16)
+ {
+ _func = (output == nullptr) ? &accumulate_bias<float16_t, float16_t, true> : &accumulate_bias<float16_t, float16_t, false>;
+ }
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
else if(input->info()->data_type() == DataType::QS8)
{
_func = (output == nullptr) ? &accumulate_bias<qint8_t, qint8_t, true> : &accumulate_bias<qint8_t, qint8_t, false>;