[ONCPUML-951] Variable weight support for Convolution.

API changes for NEGEMMConvolutionLayer and CpuGemmConv2d: the internal
configure_mm()/validate_mm() paths now take a fixed_format flag together with
an arm_gemm::WeightFormat, a has_opt_impl() entry point is added to
CpuGemmConv2d so callers can query the weight format expected by the optimal
kernel, and weight reshaping is skipped at prepare()/run() time when a
variable-weights (fixed-format) kernel is selected.
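
For context, a minimal caller-side sketch of how the new entry point could be
used: first query the weight format expected by the optimal fixed-format
kernel via CpuGemmConv2d::has_opt_impl(), then re-validate with a WeightsInfo
carrying that format so that fixed_format becomes true inside validate_mm().
This sketch is not part of the patch; the query_and_validate() helper is
hypothetical, and the WeightsInfo constructor taking a weight format, the
arm_gemm::WeightFormat::ANY query value and the include paths are assumptions.

    // Hypothetical caller-side sketch; not part of this patch.
    #include "arm_compute/core/Helpers.h"
    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/Types.h"
    #include "src/cpu/operators/CpuGemmConv2d.h"

    using namespace arm_compute;

    Status query_and_validate(const ITensorInfo *src, const ITensorInfo *weights,
                              const ITensorInfo *biases, const ITensorInfo *dst,
                              const PadStrideInfo &conv_info)
    {
        const int          idx_w         = get_data_layout_dimension_index(src->data_layout(), DataLayoutDimension::WIDTH);
        const int          idx_h         = get_data_layout_dimension_index(src->data_layout(), DataLayoutDimension::HEIGHT);
        const unsigned int kernel_width  = weights->dimension(idx_w);
        const unsigned int kernel_height = weights->dimension(idx_h);
        const unsigned int num_kernels   = weights->dimension(3);

        // Ask which weight format the optimal kernel expects. Passing
        // arm_gemm::WeightFormat::ANY is assumed to mean "any fixed format".
        arm_gemm::WeightFormat expected_wf = arm_gemm::WeightFormat::UNSPECIFIED;
        const WeightsInfo      query_wi(false, kernel_width, kernel_height, num_kernels, false, arm_gemm::WeightFormat::ANY);
        const Status           status = cpu::CpuGemmConv2d::has_opt_impl(expected_wf, src, weights, biases, dst, conv_info, query_wi,
                                                                         Size2D(1U, 1U), ActivationLayerInfo(), false /* enable_fast_math */);
        if(!bool(status) || expected_wf == arm_gemm::WeightFormat::UNSPECIFIED)
        {
            return status; // no fixed-format kernel: keep the default (reshaped-weights) path
        }

        // Re-validate with the expected format; weight_format() != UNSPECIFIED
        // makes validate() take the fixed-format branch.
        const WeightsInfo fixed_wi(false, kernel_width, kernel_height, num_kernels, false, expected_wf);
        return cpu::CpuGemmConv2d::validate(src, weights, biases, dst, conv_info, fixed_wi,
                                            Size2D(1U, 1U), ActivationLayerInfo(), false /* enable_fast_math */, 1 /* num_groups */);
    }

At execution time, prepare() then skips the weights-reshape kernel whenever
isVarWeightsKernel() reports that a fixed-format kernel was selected, as the
diff below shows.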

Built with:

    scons neon=1 opencl=0 os=linux arch=armv8.2-a multi_isa=1 \
        build=native -j32 Werror=false validation_tests=1 build_dir=opt \
        standalone=1 asserts=1 experimental_fixed_format_kernels=1 .

Tested with:

    ./build/opt/tests/arm_compute_validation

Hardware where the test executable was run:

Neoverse N1

Test coverage:

* NEGEMMConvolutionLayer, CpuGemmConv2d
* NHWC (the only data layout supported by the fixed-format kernels)
* F16, F32
* Shapes: RunSmall
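
If needed, the run above can be narrowed to the affected fixtures with the
validation framework's --filter option; the regex below is only illustrative:

    ./build/opt/tests/arm_compute_validation --filter='.*ConvolutionLayer.*'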

Change-Id: I4fd3e495a7cbf61210ea02d37440ba9652934e99
Signed-off-by: Francesco Petrogalli <francesco.petrogalli@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7632
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index c021d31..0174d0e 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -99,15 +99,15 @@
 CpuGemmConv2d::~CpuGemmConv2d() = default;
 
 void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const ActivationLayerInfo &act_info,
-                                 bool enable_fast_math, int gemm_3d_depth)
+                                 bool enable_fast_math, int gemm_3d_depth, bool fixed_format, arm_gemm::WeightFormat weight_format)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights);
-    ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, enable_fast_math, gemm_3d_depth, _skip_im2col));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, enable_fast_math, gemm_3d_depth, _skip_im2col, fixed_format, weight_format));
 
     // Create GEMMInfo structure
     const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
                                          gemm_3d_depth, _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */,
-                                         false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info);
+                                         false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, experimental::PostOpList<ITensorInfo *>(), fixed_format, weight_format);
 
     // Supported activations in GEMM
     const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = { ActivationLayerInfo::ActivationFunction::RELU,
@@ -156,7 +156,8 @@
         quantization::calculate_quantized_multipliers(iqinfo, wqinfo, oqinfo, output_info);
 
         _mm_gemmlowp = std::make_unique<CpuGemmLowpMatrixMultiplyCore>();
-        _mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, _skip_im2col, false, output_info, false, enable_fast_math, false, act_info));
+        _mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, _skip_im2col, false, output_info, false, enable_fast_math, false, act_info,
+                                                                              experimental::PostOpList<ITensorInfo *>(), fixed_format, weight_format));
 
         auto mm_mem_req = _mm_gemmlowp->workspace();
         for(unsigned int cont = 0; cont < mm_mem_req.size(); ++cont)
@@ -178,7 +179,7 @@
 }
 
 Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
-                                  const ActivationLayerInfo &act_info, bool enable_fast_math, int gemm_3d_depth, bool skip_im2col)
+                                  const ActivationLayerInfo &act_info, bool enable_fast_math, int gemm_3d_depth, bool skip_im2col, bool fixed_format, arm_gemm::WeightFormat weight_format)
 {
     const DataType data_type             = src->data_type();
     const bool     is_quantized          = is_data_type_quantized_asymmetric(data_type);
@@ -187,7 +188,7 @@
     // Create GEMMInfo structure
     const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
                                         gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */,
-                                        false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info);
+                                        false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, experimental::PostOpList<ITensorInfo *>(), fixed_format, weight_format);
 
     if(is_quantized)
     {
@@ -227,6 +228,7 @@
         std::unique_ptr<ITensorInfo> weights_qa = weights->clone();
         input_qa->set_quantization_info(QuantizationInfo(iqinfo.uniform().scale, -iqinfo.uniform().offset));
         weights_qa->set_quantization_info(QuantizationInfo(wqinfo.uniform().scale, -wqinfo.uniform().offset));
+
         return CpuGemmLowpMatrixMultiplyCore::validate(input_qa.get(), weights_qa.get(), biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, skip_im2col, false, output_info, false, enable_fast_math,
                                                                                                                false, act_info));
     }
@@ -294,6 +296,7 @@
                                                  kernel_height,
                                                  conv_info,
                                                  dilation);
+
     ARM_COMPUTE_ERROR_ON_MSG((dst->dimension(idx_width) != conv_w) || (dst->dimension(idx_height) != conv_h),
                              "Output shape does not match the expected one");
 
@@ -357,7 +360,8 @@
     // Configure GEMM
     // In case we need to skip col2im, GEMM3D (gemm_3d_depth != 0) must be called in order to avoid reshaping the output matrix
     const unsigned int gemm_3d_depth = _skip_col2im ? conv_h : 0;
-    configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, enable_fast_math, gemm_3d_depth);
+    const bool         fixed_format  = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED;
+    configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, enable_fast_math, gemm_3d_depth, fixed_format, weights_info.weight_format());
 
     if(!_skip_col2im && _data_layout == DataLayout::NCHW)
     {
@@ -384,6 +388,38 @@
     _aux_mem[GemmOutput]      = MemoryInfo(offset_int_vec(GemmOutput), MemoryLifetime::Temporary, _gemm_output.total_size());
 }
 
+Status CpuGemmConv2d::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
+                                   const PadStrideInfo &conv_info,
+                                   const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, const bool enable_fast_math)
+{
+    const DataLayout   data_layout   = src->data_layout();
+    const int          idx_width     = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+    const int          idx_height    = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+    const unsigned int kernel_width  = weights->dimension(idx_width);
+    const unsigned int kernel_height = weights->dimension(idx_height);
+    unsigned int       conv_w        = 0;
+    unsigned int       conv_h        = 0;
+    std::tie(conv_w, conv_h)         = scaled_dimensions(src->dimension(idx_width),
+                                                         src->dimension(idx_height),
+                                                         kernel_width,
+                                                         kernel_height,
+                                                         conv_info,
+                                                         dilation);
+
+    const CpuGemmConv2d::SkipInfo skip_info = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info,
+                                                                              dilation, act_info);
+
+    const bool         skip_im2col   = skip_info.skip_im2col;
+    const bool         skip_col2im   = skip_info.skip_col2im;
+    const unsigned int gemm_3d_depth = skip_col2im ? conv_h : 0;
+    const bool         fixed_format  = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED;
+    const GEMMInfo     gemm_info     = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
+                                                gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */,
+                                                false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, experimental::PostOpList<ITensorInfo *>(), fixed_format, weights_info.weight_format());
+
+    return CpuGemm::has_opt_impl(expected_weight_format, src, weights, biases, dst, gemm_info);
+}
+
 Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const PadStrideInfo &conv_info,
                                const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups)
 {
@@ -450,7 +486,7 @@
         {
             ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, biases);
         }
-        ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(idx_kernels));
+        ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != dst->dimension(idx_channel));
         ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
     }
 
@@ -472,7 +508,7 @@
 
         im2col_reshaped_info = TensorInfo(shape_im2col, 1, data_type);
         im2col_reshaped_info.set_quantization_info(src->quantization_info());
-        ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuIm2ColKernel::validate(src, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation));
+        ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuIm2ColKernel::validate(src, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation, 1));
         gemm_input_to_use = &im2col_reshaped_info;
     }
 
@@ -490,8 +526,11 @@
         info_gemm = TensorInfo(dst->tensor_shape(), 1, output_data_type);
     }
     info_gemm.set_quantization_info(dst->quantization_info()).set_data_layout(src->data_layout());
-    gemm_output_to_use = &info_gemm;
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, enable_fast_math, skip_col2im ? conv_h : 0, skip_im2col));
+    gemm_output_to_use      = &info_gemm;
+    const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED;
+
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, enable_fast_math, skip_col2im ? conv_h : 0, skip_im2col, fixed_format,
+                                            weights_info.weight_format()));
 
     // Validate Col2Im/ReshapeLayer
     if(!skip_col2im && (data_layout == DataLayout::NCHW))
@@ -548,7 +587,10 @@
     // Runs CpuGemm or CpuGemmLowpMatrixMultiplyCore functions
     ITensorPack pack_mm = tensors;
     pack_mm.add_const_tensor(TensorType::ACL_SRC_0, gemm_input_to_use);
-    pack_mm.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
+    if(!this->isVarWeightsKernel())
+    {
+        pack_mm.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
+    }
     pack_mm.add_tensor(TensorType::ACL_DST, gemm_output_to_use);
     if(_is_quantized)
     {
@@ -598,6 +640,15 @@
 {
     if(!_is_prepared)
     {
+        // Variable weights executions that use fixed-format kernels
+        // need no reshaping of the weights.
+        if(this->isVarWeightsKernel())
+        {
+            _is_quantized ? _mm_gemmlowp->prepare(tensors) : _mm_gemm->prepare(tensors);
+            _is_prepared = true;
+            return;
+        }
+
         // Run weights reshaping and mark original weights tensor as unused
         CpuAuxTensorHandler weights_reshaped(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors);
         auto                weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
@@ -608,12 +659,9 @@
         };
         NEScheduler::get().schedule_op(_weights_reshape_kernel.get(), 3, _weights_reshape_kernel->window(), pack);
         weights->mark_as_unused();
-
-        // Prepare GEMM
         ITensorPack gemm_pack = tensors;
         gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, weights_reshaped.get());
         _is_quantized ? _mm_gemmlowp->prepare(gemm_pack) : _mm_gemm->prepare(gemm_pack);
-
         _is_prepared = true;
     }
 }
@@ -621,5 +669,9 @@
 {
     return _aux_mem;
 }
+bool CpuGemmConv2d::isVarWeightsKernel() const
+{
+    return _mm_gemm && _mm_gemm->isVarWeightsKernel();
+}
 } // namespace cpu
 } // namespace arm_compute