COMPMID-3324: Fix per-channel quantization on N blocking
Pass the starting column offset to the requantization code so that the per-channel multipliers and shifts are indexed correctly when the output is blocked along N.
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I8231e0b541c6b1b76becf349a1d6ddf973ade9e2
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3488
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
index d9b1a71..2b936d0 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
@@ -228,7 +228,7 @@
requantize_block_32(_qp, (nmax - n0), (m_end - m_start), result_buffer, (nmax - n0),
this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (m_start * this->_ldc) + n0, this->_ldc,
- local_row_sums, col_bias + (multi * _Nsize) + n0);
+ local_row_sums, col_bias + (multi * _Nsize) + n0, n0);
}
} while (p.next_dim0());
}
diff --git a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
index 18f030f..9957165 100644
--- a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
+++ b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
@@ -115,7 +115,7 @@
_args._Nsize,
this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (first_row * this->_ldc), this->_ldc,
_row_sums + (multi * _args._nbatches * _args._Msize) + (batch * _args._Msize) + first_row,
- _col_sums + (multi * _args._Nsize));
+ _col_sums + (multi * _args._Nsize), 0);
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp
index 00b42cf..53e5527 100644
--- a/src/core/NEON/kernels/arm_gemm/quantized.cpp
+++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp
@@ -57,7 +57,7 @@
template<bool do_shift_correction, bool per_channel>
void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigned int height,
const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
- const int32_t *row_bias, const int32_t *col_bias) {
+ const int32_t *row_bias, const int32_t *col_bias, const unsigned int start_col) {
const int32x4_t v_mul = vdupq_n_s32(qp.per_layer_mul);
const int32x4_t v_shift = vdupq_n_s32(qp.per_layer_shift);
const int32x4_t v_minval = vdupq_n_s32(qp.minval);
@@ -76,8 +76,8 @@
unsigned int odds=(width % 4);
const int32_t *colptr = col_bias;
- const int32_t *perch_mul_ptr = qp.per_channel_muls;
- const int32_t *perch_shift_ptr = qp.per_channel_shifts;
+ const int32_t *perch_mul_ptr = qp.per_channel_muls + start_col;
+ const int32_t *perch_shift_ptr = qp.per_channel_shifts + start_col;
const int32_t *in_ptr = input + (row * in_stride);
int8_t *out_ptr = output + (row * out_stride);
@@ -461,33 +461,33 @@
template<typename Tin, typename Tout>
void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride,
- const int32_t *row_bias, const int32_t *col_bias) {
+ const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col) {
if (qp.per_channel_requant) {
if (qp.minval >= qp.c_offset) {
requantize_block_32_int<false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
- reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
} else {
requantize_block_32_int<true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
- reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
}
} else {
if (qp.minval >= qp.c_offset) {
requantize_block_32_int<false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
- reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
} else {
requantize_block_32_int<true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
- reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
}
}
}
template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
- const int32_t *row_bias, const int32_t *col_bias);
+ const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col);
template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
const uint32_t *input, unsigned int in_stride, uint8_t *output, unsigned int out_stride,
- const int32_t *row_bias, const int32_t *col_bias);
+ const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col);
/*
* Routine (and helpers) to compute row sums needed for offset correction.
diff --git a/src/core/NEON/kernels/arm_gemm/quantized.hpp b/src/core/NEON/kernels/arm_gemm/quantized.hpp
index a91a888..b0e0c3b 100644
--- a/src/core/NEON/kernels/arm_gemm/quantized.hpp
+++ b/src/core/NEON/kernels/arm_gemm/quantized.hpp
@@ -28,7 +28,7 @@
template<typename Tin, typename Tout>
void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride,
- const int32_t *row_bias, const int32_t *col_bias);
+ const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col);
template<typename T>
void compute_row_sums(const Requantize32 &qp, unsigned int width, unsigned int height,