COMPMID-428: Port NESoftmaxLayer to 16-bit fixed point.

Change-Id: I65122950bab9124b9758c27096c0f458b77aeabb
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79365
Reviewed-by: Moritz Pflanzer <moritz.pflanzer@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Steven Niu <steven.niu@arm.com>
diff --git a/arm_compute/core/FixedPoint.h b/arm_compute/core/FixedPoint.h
index 774125e..f166d93 100644
--- a/arm_compute/core/FixedPoint.h
+++ b/arm_compute/core/FixedPoint.h
@@ -29,6 +29,7 @@
 using qint8_t  = int8_t;  /**< 8 bit fixed point scalar value */
 using qint16_t = int16_t; /**< 16 bit fixed point scalar value */
 using qint32_t = int32_t; /**< 32 bit fixed point scalar value */
+using qint64_t = int64_t; /**< 64 bit fixed point scalar value */
 
 /** 8 bit fixed point scalar saturating shift left
  *
@@ -100,6 +101,15 @@
  */
 qint16_t sqadd_qs16(qint16_t a, qint16_t b);
 
+/** 32 bit fixed point scalar saturating add
+ *
+ * @param[in] a First 32 bit fixed point input
+ * @param[in] b Second 32 bit fixed point input
+ *
+ * @return The result of the 32 bit fixed point addition. The result is saturated in case of overflow
+ */
+qint32_t sqadd_qs32(qint32_t a, qint32_t b);
+
 /** 8 bit fixed point scalar subtraction
  *
  * @param[in] a First 8 bit fixed point input
@@ -332,6 +342,14 @@
  * @return The narrowing conversion to 8 bit
  */
 qint8_t sqmovn_qs16(qint16_t a);
+
+/** Scalar saturating move and narrow.
+ *
+ * @param[in] a Input to convert to 16 bit fixed point
+ *
+ * @return The narrowing conversion to 16 bit
+ */
+qint16_t sqmovn_qs32(qint32_t a);
 }
 #include "arm_compute/core/FixedPoint.inl"
 #endif /* __ARM_COMPUTE_FIXEDPOINT_H__ */
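For reference, a minimal usage sketch of the two scalar helpers declared above (illustrative only, not part of the patch; it assumes the helpers live in the arm_compute namespace and reads qint32_t as Q15.16 purely for the comments):

// Illustrative only (not part of the patch): expected behaviour of the new
// 32-bit scalar helpers, assuming a Q15.16 interpretation of qint32_t.
#include "arm_compute/core/FixedPoint.h"
#include <cassert>
#include <cstdint>

void example_scalar_qs32()
{
    using namespace arm_compute; // assumption: helpers sit in the arm_compute namespace

    // 1.5 + 1.25 = 2.75 in Q15.16 (1.0 == 1 << 16)
    qint32_t a = 3 << 15; // 1.5
    qint32_t b = 5 << 14; // 1.25
    assert(sqadd_qs32(a, b) == (11 << 14)); // 2.75

    // Saturation instead of wrap-around on overflow
    assert(sqadd_qs32(INT32_MAX, 1) == INT32_MAX);

    // Narrowing 32 -> 16 bit saturates values outside the qint16_t range
    assert(sqmovn_qs32(0x12345) == INT16_MAX);
    assert(sqmovn_qs32(-42) == -42);
}
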
diff --git a/arm_compute/core/FixedPoint.inl b/arm_compute/core/FixedPoint.inl
index fdbc3f0..b921b32 100644
--- a/arm_compute/core/FixedPoint.inl
+++ b/arm_compute/core/FixedPoint.inl
@@ -90,13 +90,22 @@
 
 inline qint16_t sqadd_qs16(qint16_t a, qint16_t b)
 {
-    // We need to store the temporary result in qint16_t otherwise we cannot evaluate the overflow
+    // We need to store the temporary result in qint32_t otherwise we cannot evaluate the overflow
     qint32_t tmp = (static_cast<qint32_t>(a) + static_cast<qint32_t>(b));
 
     // Saturate the result in case of overflow and cast to qint16_t
     return saturate_convert<qint32_t, qint16_t>(tmp);
 }
 
+inline qint32_t sqadd_qs32(qint32_t a, qint32_t b)
+{
+    // We need to store the temporary result in qint64_t otherwise we cannot evaluate the overflow
+    qint64_t tmp = (static_cast<qint64_t>(a) + static_cast<qint64_t>(b));
+
+    // Saturate the result in case of overflow and cast to qint32_t
+    return saturate_convert<qint64_t, qint32_t>(tmp);
+}
+
 inline qint8_t ssub_qs8(qint8_t a, qint8_t b)
 {
     return a - b;
@@ -388,4 +397,10 @@
     // Saturate the result in case of overflow and cast to qint8_t
     return saturate_convert<qint16_t, qint8_t>(a);
 }
+
+inline qint16_t sqmovn_qs32(qint32_t a)
+{
+    // Saturate the result in case of overflow and cast to qint16_t
+    return saturate_convert<qint32_t, qint16_t>(a);
+}
 }
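The new sqadd_qs32() follows the same widen-then-clamp pattern as the existing 8- and 16-bit versions: do the addition in a wider type so the overflow is observable, then saturate back. A self-contained sketch of the equivalent logic (illustrative only; saturate_convert<qint64_t, qint32_t> is assumed to clamp to the numeric limits of qint32_t):

// Illustrative only (not part of the patch): what the widen-then-saturate
// pattern in sqadd_qs32() amounts to, written as a standalone function.
#include <cstdint>
#include <limits>

int32_t sqadd_qs32_sketch(int32_t a, int32_t b)
{
    // Do the addition in 64 bit so the overflow is visible...
    const int64_t tmp = static_cast<int64_t>(a) + static_cast<int64_t>(b);

    // ...then clamp back into the 32-bit range, which is what
    // saturate_convert<qint64_t, qint32_t>() is expected to do.
    const int64_t lo = std::numeric_limits<int32_t>::min();
    const int64_t hi = std::numeric_limits<int32_t>::max();
    return static_cast<int32_t>(tmp < lo ? lo : (tmp > hi ? hi : tmp));
}
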
diff --git a/arm_compute/core/NEON/NEFixedPoint.h b/arm_compute/core/NEON/NEFixedPoint.h
index e30509c..09579f9 100644
--- a/arm_compute/core/NEON/NEFixedPoint.h
+++ b/arm_compute/core/NEON/NEFixedPoint.h
@@ -46,6 +46,7 @@
 using qint16x8x2_t = int16x8x2_t; /**< 16 bit fixed point vector with 16 elements */
 using qint16x8x3_t = int16x8x3_t; /**< 16 bit fixed point vector with 24 elements */
 using qint16x8x4_t = int16x8x4_t; /**< 16 bit fixed point vector with 32 elements */
+using qint32x2_t   = int32x2_t;   /**< 32 bit fixed point vector with 2 elements */
 using qint32x4_t   = int32x4_t;   /**< 32 bit fixed point vector with 4 elements */
 
 /** Get the lower half of a 16 elements vector
diff --git a/arm_compute/core/NEON/NEFixedPoint.inl b/arm_compute/core/NEON/NEFixedPoint.inl
index b241dd5..f62a338 100644
--- a/arm_compute/core/NEON/NEFixedPoint.inl
+++ b/arm_compute/core/NEON/NEFixedPoint.inl
@@ -384,6 +384,11 @@
     return vqadd_s16(a, b);
 }
 
+inline qint32x2_t vqadd_qs32(qint32x2_t a, qint32x2_t b)
+{
+    return vqadd_s32(a, b);
+}
+
 inline qint8x16_t vqaddq_qs8(qint8x16_t a, qint8x16_t b)
 {
     return vqaddq_s8(a, b);
@@ -394,6 +399,11 @@
     return vqaddq_s16(a, b);
 }
 
+inline qint32x4_t vqaddq_qs32(qint32x4_t a, qint32x4_t b)
+{
+    return vqaddq_s32(a, b);
+}
+
 inline int16x4_t vpaddl_qs8(qint8x8_t a)
 {
     return vpaddl_s8(a);
@@ -1073,6 +1083,56 @@
     return vshl_s16(x, shift_value);
 }
 
+inline qint8x8_t vqrecip_qs8(qint8x8_t a, int fixed_point_position)
+{
+    // We need two bits to store 2, thus we can only support formats from Q2.5 to Q7.0
+    const qint8x8_t const_48_over_17 = vdup_n_s8(0x5A >> (5 - fixed_point_position));   // 2.823
+    const qint8x8_t const_32_over_17 = vdup_n_s8(0x3C >> (5 - fixed_point_position));   // 1.8823
+    const qint8x8_t const_one        = vdup_n_s8(1 << fixed_point_position);
+
+    // Find shift value
+    const qint8x8_t shift_value = vqneg_s8(vsub_s8(vdup_n_s8(8), vqadd_s8(vclz_s8(a), vdup_n_s8(fixed_point_position))));
+    const qint8x8_t temp        = vqshl_s8(a, shift_value);
+
+    qint8x8_t x = vqadd_s8(const_48_over_17, vqmul_qs8(temp, const_32_over_17, fixed_point_position));
+
+    uint8x8_t set_one = vcgt_s8(x, const_one);
+    x                 = vbsl_s8(set_one, const_one, x);
+
+    // Use three iterations of Newton-Raphson method to get the result
+    x = vqadd_s8(x, vqmul_qs8(x, vqsub_s8(const_one, vqmul_qs8(temp, x, fixed_point_position)), fixed_point_position));
+    x = vqadd_s8(x, vqmul_qs8(x, vqsub_s8(const_one, vqmul_qs8(temp, x, fixed_point_position)), fixed_point_position));
+    x = vqadd_s8(x, vqmul_qs8(x, vqsub_s8(const_one, vqmul_qs8(temp, x, fixed_point_position)), fixed_point_position));
+
+    return vqshl_s8(x, shift_value);
+}
+
+inline qint16x4_t vqrecip_qs16(qint16x4_t a, int fixed_point_position)
+{
+    // We need two bits to store 2, thus we can only support formats from Q2.13 to Q15.0
+    const qint16x4_t const_48_over_17 = vdup_n_s16(0x5A5A >> (13 - fixed_point_position)); // 2.823
+    const qint16x4_t const_32_over_17 = vdup_n_s16(0x3C3C >> (13 - fixed_point_position)); // 1.8823
+    const qint16x4_t const_one        = vdup_n_s16(1 << fixed_point_position);
+
+    // Find shift value
+    const qint16x4_t shift_value = vqneg_s16(vqsub_s16(vdup_n_s16(16), vqadd_s16(vclz_s16(a), vdup_n_s16(fixed_point_position))));
+    const qint16x4_t temp        = vqshl_s16(a, shift_value);
+
+    qint16x4_t x = vqadd_s16(const_48_over_17, vqmul_qs16(temp, const_32_over_17, fixed_point_position));
+
+    uint16x4_t set_one = vcgt_s16(x, const_one);
+    x                  = vbsl_s16(set_one, const_one, x);
+
+    // Use five iterations of Newton-Raphson method to get the result
+    x = vqadd_s16(x, vmul_qs16(x, vqsub_s16(const_one, vqmul_qs16(temp, x, fixed_point_position)), fixed_point_position));
+    x = vqadd_s16(x, vmul_qs16(x, vqsub_s16(const_one, vqmul_qs16(temp, x, fixed_point_position)), fixed_point_position));
+    x = vqadd_s16(x, vmul_qs16(x, vqsub_s16(const_one, vqmul_qs16(temp, x, fixed_point_position)), fixed_point_position));
+    x = vqadd_s16(x, vmul_qs16(x, vqsub_s16(const_one, vqmul_qs16(temp, x, fixed_point_position)), fixed_point_position));
+    x = vqadd_s16(x, vmul_qs16(x, vqsub_s16(const_one, vqmul_qs16(temp, x, fixed_point_position)), fixed_point_position));
+
+    return vqshl_s16(x, shift_value);
+}
+
 inline qint8x16_t vrecipq_qs8(qint8x16_t a, int fixed_point_position)
 {
     // We need two bits to store 2, thus we can only support formats from Q2.5 to Q7.0
@@ -1817,7 +1877,7 @@
     qint8x8_t exp2x = vqexp_qs8(vqmul_qs8(const_two, a, fixed_point_position), fixed_point_position);
     qint8x8_t num   = vqsub_qs8(exp2x, const_one);
     qint8x8_t den   = vqadd_qs8(exp2x, const_one);
-    qint8x8_t tanh  = vqmul_qs8(num, vrecip_qs8(den, fixed_point_position), fixed_point_position);
+    qint8x8_t tanh  = vqmul_qs8(num, vqrecip_qs8(den, fixed_point_position), fixed_point_position);
 
     return tanh;
 }
@@ -1830,7 +1890,7 @@
     qint16x4_t exp2x = vqexp_qs16(vqmul_qs16(const_two, a, fixed_point_position), fixed_point_position);
     qint16x4_t num   = vqsub_qs16(exp2x, const_one);
     qint16x4_t den   = vqadd_qs16(exp2x, const_one);
-    qint16x4_t tanh  = vqmul_qs16(num, vrecip_qs16(den, fixed_point_position), fixed_point_position);
+    qint16x4_t tanh  = vqmul_qs16(num, vqrecip_qs16(den, fixed_point_position), fixed_point_position);
 
     return tanh;
 }
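The new vqrecip_qs8/vqrecip_qs16 helpers (now used by vqtanh above) follow the usual Newton-Raphson reciprocal scheme: normalise the input into [0.5, 1), start from the 48/17 - 32/17 * d estimate, refine (three iterations for 8 bit, five for 16 bit), then shift back. A plain floating-point sketch of the same recurrence (illustrative only, not part of the patch):

// Illustrative only (not part of the patch): the recurrence behind
// vqrecip_qs8/vqrecip_qs16, written in plain float for one positive scalar.
#include <cmath>

float recip_newton_sketch(float a, int iterations)
{
    // Normalise a into [0.5, 1): a = d * 2^shift
    int   shift = 0;
    float d     = std::frexp(a, &shift);

    // Initial estimate from the 48/17 - 32/17 * d polynomial used above
    float x = 48.0f / 17.0f - (32.0f / 17.0f) * d;

    // Newton-Raphson refinement: x <- x + x * (1 - d * x) converges to 1/d
    for(int i = 0; i < iterations; ++i)
    {
        x = x + x * (1.0f - d * x);
    }

    // Undo the normalisation: 1/a = (1/d) * 2^-shift
    return std::ldexp(x, -shift);
}

The vector versions additionally clamp the initial estimate to at most 1.0 and keep every step in saturating fixed-point arithmetic.
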
diff --git a/arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h b/arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h
index ab626ad..53eef8d 100644
--- a/arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h
+++ b/arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h
@@ -39,7 +39,7 @@
     NELogits1DMaxKernel();
     /** Set the input and output tensors.
      *
-     * @param[in]  input  Source tensor. Data types supported: QS8, F32.
+     * @param[in]  input  Source tensor. Data types supported: QS8/QS16/F32.
      * @param[out] output Destination tensor. Data types supported: same as @p input
      */
     void configure(const ITensor *input, ITensor *output);
@@ -74,7 +74,7 @@
     ~NELogits1DShiftExpSumKernel() = default;
     /** Set the input and output tensors.
      *
-     * @param[in]  input  Source tensor. Data types supported: QS8, F32.
+     * @param[in]  input  Source tensor. Data types supported: QS8/QS16/F32.
      * @param[in]  max    Max values tensor. Data types supported: same as @p input.
      * @param[out] output Destination tensor. Data types supported: same as @p input.
      * @param[out] sum    Sum of 1D logits tensor. Data types supported: same as @p input.
@@ -113,7 +113,7 @@
     ~NELogits1DNormKernel() = default;
     /** Set the input and output tensors.
      *
-     * @param[in]  input  Source tensor. Data types supported: QS8, F32.
+     * @param[in]  input  Source tensor. Data types supported: QS8/QS16/F32.
      * @param[in]  sum    Sum tensor. The number of dimensions should be dim(input)-1. Data types supported: same as @p input.
      * @param[out] output Destination tensor. Data types supported: same as @p input.
      */
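Taken together, the three kernels above implement the numerically stable softmax split: a row-wise max, a shifted exponential plus running sum, and a final normalisation. A scalar reference of that decomposition (illustrative only, not part of the patch):

// Illustrative only (not part of the patch): the decomposition the three
// kernels above correspond to, written as a float reference.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> softmax_reference(const std::vector<float> &x)
{
    const float m = *std::max_element(x.begin(), x.end()); // NELogits1DMaxKernel
    std::vector<float> y(x.size());
    float sum = 0.f;
    for(std::size_t i = 0; i < x.size(); ++i)              // NELogits1DShiftExpSumKernel
    {
        y[i] = std::exp(x[i] - m);
        sum += y[i];
    }
    for(auto &v : y)                                        // NELogits1DNormKernel
    {
        v /= sum;
    }
    return y;
}
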
diff --git a/arm_compute/runtime/NEON/functions/NESoftmaxLayer.h b/arm_compute/runtime/NEON/functions/NESoftmaxLayer.h
index dc84dec..44a69d8 100644
--- a/arm_compute/runtime/NEON/functions/NESoftmaxLayer.h
+++ b/arm_compute/runtime/NEON/functions/NESoftmaxLayer.h
@@ -50,7 +50,7 @@
     NESoftmaxLayer();
     /** Set the input and output tensors.
      *
-     * @param[in]  input  Source tensor. Data types supported: QS8/F32.
+     * @param[in]  input  Source tensor. Data types supported: QS8/QS16/F32.
      * @param[out] output Destination tensor. Data types supported: same as @p input.
      */
     void configure(ITensor *input, ITensor *output);
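Finally, a configuration sketch for the newly supported QS16 path (illustrative only, not part of the patch; shape and fixed point position are placeholders, and the TensorInfo constructor taking a fixed point position is assumed from the 17.x API):

// Illustrative only (not part of the patch): configuring the softmax function
// with a QS16 input, which this change enables.
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/NEON/functions/NESoftmaxLayer.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

void run_softmax_qs16_example()
{
    Tensor src, dst;

    // 1D logits of length 128 in QS16 with 12 fractional bits (placeholder values)
    const TensorInfo info(TensorShape(128U), 1, DataType::QS16, 12);
    src.allocator()->init(info);
    dst.allocator()->init(info);

    NESoftmaxLayer softmax;
    softmax.configure(&src, &dst); // output uses the same data type as the input

    src.allocator()->allocate();
    dst.allocator()->allocate();

    // ... fill src with QS16 logits ...
    softmax.run();
}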