/*
 * Copyright (c) 2021-2024 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/cpu/operators/CpuGemmConv2d.h"

#include "arm_compute/core/Size2D.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"

#include "src/common/utils/Log.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/MemoryHelpers.h"
#include "src/core/helpers/Utils.h"
#include "src/cpu/kernels/CpuCol2ImKernel.h"
#include "src/cpu/kernels/CpuIm2ColKernel.h"
#include "src/cpu/kernels/CpuWeightsReshapeKernel.h"
#include "src/cpu/operators/CpuGemm.h"
#include "src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h"
#include "src/cpu/operators/CpuGemmLowpOutputStage.h"
#include "src/cpu/operators/CpuReshape.h"
#include "src/cpu/utils/CpuAuxTensorHandler.h"

#include <set>
#include <tuple>

using namespace arm_compute::misc::shape_calculator;
using namespace arm_compute::experimental;

namespace arm_compute
{
namespace cpu
{

/** @section note_CpuGemmConv2d_weight_transformation Weight Transformations in CpuGemmConv2d
 *
 * A. Terminology
 * Throughout CpuGemmConv2d, we use the following terms in ways that may differ from other operators / kernels:
 *   - "Transform" or "Reshape" of the weights: both refer to all the operations that we perform on the weight
 *     tensor up until it is consumed by gemm (CpuGemm or CpuGemmLowpMatrixMultiplyCore).
 *     Note that the specific gemm operator may perform further transformations on the weights, but the
 *     transformations here only mean those performed in CpuGemmConv2d.
 *   - "Transpose" of weights: the @ref CpuTranspose operation, i.e. transpose of the weights' lowest two
 *     dimensions.
 *
 * B. Gemm-based conv2d
 * We want to convert the 2d convolution op (ignoring bias):
 *      dst = conv2d(src, weight)
 * into a matrix multiplication op:
 *      gemm_dst = gemm(lhs, rhs)
 *
 * E.g.: For data layout NHWC
 *                        3 (hi)  <---------->  (lo) 0
 *      src.shape    = [batch, in_h , in_w,  in_c ]
 *      weight.shape = [out_c, k_h  , k_w,   in_c ]
 *      dst.shape    = [batch, out_h, out_w, out_c]
 *
 * This requires three transformations:
 *      * src -> lhs: transform the conv input to the gemm lhs; gemm_lhs is a 2d matrix where each row (or column,
 *        depending on the convention) is a linearized "patch" of the conv input that corresponds to
 *        the receptive field of the corresponding output element.
 *        The convention is to use "column", but to disambiguate from the column vector of a matrix,
 *        in this documentation we shall use "patch".
 *        This transform is called im2col (for details see @ref CpuIm2ColKernel)
 *      * weight -> rhs: transform the conv weight to the gemm rhs, known as weight transform/reshape (wt)
 *      * gemm_dst -> dst: transform the gemm output back to the conv output, known as col2im (for details see
 *        @ref CpuCol2ImKernel)
 *
 * This section focuses on the weight transformation and assumes im2col has already been performed
 *
 * C. Weight Transformation
 * After im2col, assume: lhs.shape = [num_patch, patch_size],
 *      where patch_size is the number of elements in a "patch": patch_size = k_h * k_w * in_c
 *      num_patch is the number of patches; we can ignore it here (for details see @ref CpuIm2ColKernel)
 *
 * After wt, rhs should have the shape: rhs = [patch_size, out_c]
 *
 * Therefore, the weight transformation consists of two steps:
 *      1. Collapse the 3 lowest dimensions (k_h, k_w, in_c): [out_c, k_h, k_w, in_c] -> [out_c, patch_size]
 *      2. Transpose the collapsed shape:                     [out_c, patch_size]     -> [patch_size, out_c]
 *
 * D. Implementation
 * There are 4 paths for weight transformation
 *
 * 1. Path 1: Fixed weight format - no transformation
 *      The underlying gemm kernel may adopt a fixed weight format (isVarWeightsKernel() == true), which requires
 *      that no weight transformation be performed.
 *      Note that this no-transform requirement applies both to this op (CpuGemmConv2d) and to the constituent ops,
 *      up until the fixed format kernels themselves
 *
 * 2. Path 2: Reinterpret then transpose later
 *      If the weight tensor has no "holes" (see @ref has_holes), there are two optimizations we can apply:
 *        - We can skip the first step (collapsing of the dimensions) by simply re-interpreting the shape
 *          in TensorInfo
 *        - Instead of performing the transpose here, we can pass the transpose flag to the underlying gemm. The gemm
 *          may then decide to fuse the transpose with any further transformations
 *
 * 3. Path 3: Reshape then transpose later
 *      If the weight tensor has holes, then we use a dedicated @ref CpuReshape, followed by the transpose later
 *
 * 4. Path 4: Fused reshape and transpose
 *      This is only for quantized types for now (TODO: Remove (COMPMID-6596)). We fall back to a legacy
 *      non-optimized kernel @ref CpuWeightsReshapeKernel to perform a fused reshape + transpose
 *
 * Path 1 is the long-term solution that we shall migrate to once (if) we adopt fixed weight format for all gemm
 * kernels.
 * In the short term, Path 2 is the favored, more performant path.
 */
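/* A concrete, purely illustrative example of the two transformation steps above (the shapes are hypothetical and
 * not taken from any specific test case):
 *      weight.shape = [out_c = 64, k_h = 3, k_w = 3, in_c = 16], so patch_size = 3 * 3 * 16 = 144
 *      Step 1 (collapse / reinterpret): [64, 3, 3, 16] -> [64, 144]
 *      Step 2 (transpose):              [64, 144]      -> [144, 64]
 * On Path 2 the first step is a mere reinterpretation of the TensorInfo shape, on Path 3 it is a CpuReshape, and
 * Path 4 performs both steps in a single CpuWeightsReshapeKernel pass.
 */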

namespace
{
/** Initialize reshaped / transformed weight info
 *
 * @param[in]  weights          Input weights
 * @param[out] reshaped_weights Transformed weights
 */
void initialize_reshaped_weight_info(const ITensorInfo &weights, ITensorInfo &reshaped_weights)
{
    auto_init_if_empty(reshaped_weights, weights);
    if (is_data_type_quantized(weights.data_type()))
    {
        // WT method: FusedReshapeAndTranspose
        reshaped_weights.set_tensor_shape(compute_weights_reshaped_shape(weights, /* has_bias */ false));
    }
    else
    {
        TensorShape collapsed_weights = weights.tensor_shape();
        collapsed_weights.collapse(3);
        reshaped_weights.set_tensor_shape(collapsed_weights);
    }
}
} // namespace

CpuGemmConv2d::WeightTransformMethod CpuGemmConv2d::get_wt_method(const ITensorInfo &weights)
{
    // TODO: Extend ReinterpretThenTranspose support for quantized data types COMPMID-6596
    if (is_data_type_quantized(weights.data_type()))
    {
        return WeightTransformMethod::FusedReshapeAndTranspose;
    }
    return has_holes(weights) ? WeightTransformMethod::ReshapeThenTranspose
                              : WeightTransformMethod::ReinterpretThenTranspose;
}

CpuGemmConv2d::SkipInfo CpuGemmConv2d::skip_im_col_info(const ITensorInfo         *src,
                                                        const ITensorInfo         *weights,
                                                        const PadStrideInfo       &conv_info,
                                                        const Size2D              &dilation,
                                                        const ActivationLayerInfo &act_info)
{
    const DataLayout   data_layout   = src->data_layout();
    const int          idx_width     = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
    const int          idx_height    = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
    const unsigned int kernel_width  = weights->dimension(idx_width);
    const unsigned int kernel_height = weights->dimension(idx_height);
    unsigned int       conv_w        = 0;
    unsigned int       conv_h        = 0;
    std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width), src->dimension(idx_height), kernel_width,
                                                 kernel_height, conv_info, dilation);
    const bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 &&
                              conv_info.stride().first == 1 && conv_info.stride().second == 1);

    if (skip_im2col)
    {
        const bool skip_col2im =
            (data_layout == DataLayout::NHWC &&
             (bool(CpuGemmConv2d::validate_gemm3d(src, weights, act_info, conv_h, /*skip_im2col*/ true))));
        if (skip_col2im)
        {
            return {true, true};
        }
    }
    else
    {
        const bool skip_col2im =
            (data_layout == DataLayout::NHWC &&
             (bool(CpuGemmConv2d::validate_gemm3d(src, weights, act_info, conv_h, /*skip_im2col*/ false))));
        if (skip_col2im)
        {
            return {false, true};
        }
    }

    // Default case when we cannot reinterpret the input and output as 3D.
    return {false, false};
}
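// An illustrative note on the decision above (not tied to any specific test case): an NHWC convolution with a
// 1x1 kernel and unit strides takes the skip_im2col branch; whether col2im can also be skipped additionally
// depends on validate_gemm3d() accepting the reinterpretation of the input/output as 3D.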
210
Georgios Pinitas19884632021-08-16 12:38:54 +0100211CpuGemmConv2d::CpuGemmConv2d()
SiCong Lic5ab4df2023-10-17 17:38:57 +0100212 : _weights_reshape(nullptr),
213 _weights_reshape_and_transpose_kernel(nullptr),
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100214 _im2col_kernel(),
215 _mm_gemm(),
216 _mm_gemmlowp(),
217 _col2im_kernel(),
218 _reshape(),
219 _im2col_output(),
220 _weights_reshaped(),
221 _gemm_output(),
222 _gemm_output_3d(),
223 _data_layout(DataLayout::NCHW),
224 _skip_im2col(false),
225 _skip_col2im(false),
226 _is_quantized(false),
227 _is_prepared(false),
SiCong Lic5ab4df2023-10-17 17:38:57 +0100228 _wt_method(WeightTransformMethod::ReshapeThenTranspose),
229 _run_wt(true),
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100230 _aux_mem(AuxTensorIdx::Count)
Manuel Bottini29599d02021-07-06 15:01:35 +0100231{
232}
Georgios Pinitas19884632021-08-16 12:38:54 +0100233CpuGemmConv2d::~CpuGemmConv2d() = default;

void CpuGemmConv2d::configure_mm(const ITensorInfo         *src,
                                 const ITensorInfo         *weights,
                                 const ITensorInfo         *biases,
                                 ITensorInfo               *dst,
                                 const ActivationLayerInfo &act_info,
                                 bool                       enable_fast_math,
                                 int                        gemm_3d_depth,
                                 bool                       fixed_format,
                                 arm_compute::WeightFormat  weight_format)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights);
    ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, enable_fast_math, gemm_3d_depth,
                                           _skip_im2col, fixed_format, weight_format));

    // Supported activations in GEMM
    const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = {
        ActivationLayerInfo::ActivationFunction::RELU, ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
        ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU};

    if (_is_quantized)
    {
        TensorInfo tmp_src{*src};
        TensorInfo tmp_weights{*weights};
        // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
        // Extract and negate input and weights offset
        const QuantizationInfo        iqinfo    = src->quantization_info();
        const QuantizationInfo        wqinfo    = weights->quantization_info();
        const QuantizationInfo        oqinfo    = (dst->total_size() == 0) ? iqinfo : dst->quantization_info();
        const UniformQuantizationInfo uiqinfo   = iqinfo.uniform();
        const UniformQuantizationInfo uoqinfo   = oqinfo.uniform();
        const DataType                data_type = src->data_type();

        tmp_src.set_quantization_info(QuantizationInfo(uiqinfo.scale, -uiqinfo.offset));
        if (!is_data_type_quantized_per_channel(tmp_weights.data_type()))
        {
            const UniformQuantizationInfo uwqinfo = wqinfo.uniform();
            tmp_weights.set_quantization_info(QuantizationInfo(uwqinfo.scale, -uwqinfo.offset));
        }

        // Merge activation with output stage
        PixelValue type_min{};
        PixelValue type_max{};
        std::tie(type_min, type_max) = get_min_max(data_type);
        int32_t min_activation = type_min.get<int32_t>();
        int32_t max_activation = type_max.get<int32_t>();

        if (supported_acts.count(act_info.activation()) != 0)
        {
            std::tie(min_activation, max_activation) = get_quantized_activation_min_max(act_info, data_type, uoqinfo);
        }

        GEMMLowpOutputStageInfo output_info;
        output_info.type                     = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
        output_info.gemmlowp_offset          = uoqinfo.offset;
        output_info.gemmlowp_min_bound       = min_activation;
        output_info.gemmlowp_max_bound       = max_activation;
        output_info.is_quantized_per_channel = (tmp_weights.data_type() == DataType::QSYMM8_PER_CHANNEL);
        quantization::calculate_quantized_multipliers(iqinfo, wqinfo, oqinfo, output_info);

        _mm_gemmlowp = std::make_unique<CpuGemmLowpMatrixMultiplyCore>();
        _mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases, dst,
                                GEMMInfo(false, false, true, gemm_3d_depth, _skip_im2col, false, output_info, false,
                                         enable_fast_math, false, act_info, fixed_format, weight_format,
                                         false /* pretranspose_B. TODO: COMPMID-6596 */));

        auto mm_mem_req = _mm_gemmlowp->workspace();
        for (unsigned int cont = 0; cont < mm_mem_req.size(); ++cont)
        {
            _aux_mem[cont] = mm_mem_req[cont];
        }
    }
    else
    {
        // Create GEMMInfo structure
        const GEMMInfo &gemm_info =
            GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth,
                     _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false,
                     GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, fixed_format, weight_format,
                     true /* pretranspose_B. For fp gemm (wt paths 1 - 3), we always pretranspose B (for wt path 1
                             this flag is ignored) */);
        // Configure matrix multiply function
        _mm_gemm = std::make_unique<CpuGemm>();
        _mm_gemm->configure(src, weights, biases, dst, 1.0f, 1.0f, gemm_info);
        auto mm_mem_req = _mm_gemm->workspace();
        for (unsigned int cont = 0; cont < mm_mem_req.size(); ++cont)
        {
            _aux_mem[cont] = mm_mem_req[cont];
        }
    }
}

Status CpuGemmConv2d::validate_mm(const ITensorInfo         *src,
                                  const ITensorInfo         *weights,
                                  const ITensorInfo         *biases,
                                  const ITensorInfo         *dst,
                                  const ActivationLayerInfo &act_info,
                                  bool                       enable_fast_math,
                                  int                        gemm_3d_depth,
                                  bool                       skip_im2col,
                                  bool                       fixed_format,
                                  arm_compute::WeightFormat  weight_format)
{
    const DataType data_type             = src->data_type();
    const bool     is_quantized          = is_data_type_quantized_asymmetric(data_type);
    const bool     is_activation_enabled = act_info.enabled();

    if (is_quantized)
    {
        // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
        // Extract and negate input and weights offset
        const QuantizationInfo       &iqinfo  = src->quantization_info();
        const QuantizationInfo       &wqinfo  = weights->quantization_info();
        const QuantizationInfo       &oqinfo  = (dst->total_size() == 0) ? iqinfo : dst->quantization_info();
        const UniformQuantizationInfo uoqinfo = oqinfo.uniform();

        // Merge activation with output stage
        PixelValue type_min{};
        PixelValue type_max{};
        std::tie(type_min, type_max) = get_min_max(data_type);
        int32_t min_activation = type_min.get<int32_t>();
        int32_t max_activation = type_max.get<int32_t>();

        const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = {
            ActivationLayerInfo::ActivationFunction::RELU, ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU};
        if (is_activation_enabled && supported_acts.count(act_info.activation()) != 0)
        {
            std::tie(min_activation, max_activation) = get_quantized_activation_min_max(act_info, data_type, uoqinfo);
        }

        GEMMLowpOutputStageInfo output_info;
        output_info.type                     = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
        output_info.gemmlowp_offset          = uoqinfo.offset;
        output_info.gemmlowp_min_bound       = min_activation;
        output_info.gemmlowp_max_bound       = max_activation;
        output_info.is_quantized_per_channel = (weights->data_type() == DataType::QSYMM8_PER_CHANNEL);
        ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multipliers(iqinfo, wqinfo, oqinfo, output_info));

        // Perform validation step on GEMMLowp
        std::unique_ptr<ITensorInfo> input_qa   = src->clone();
        std::unique_ptr<ITensorInfo> weights_qa = weights->clone();
        input_qa->set_quantization_info(QuantizationInfo(iqinfo.uniform().scale, -iqinfo.uniform().offset));
        weights_qa->set_quantization_info(QuantizationInfo(wqinfo.uniform().scale, -wqinfo.uniform().offset));

        return CpuGemmLowpMatrixMultiplyCore::validate(input_qa.get(), weights_qa.get(), biases, dst,
                                                       GEMMInfo(false, false, true, gemm_3d_depth, skip_im2col, false,
                                                                output_info, false, enable_fast_math, false, act_info,
                                                                false /* pretranspose_B. TODO: COMPMID-6596 */));
    }
    else
    {
        // Create GEMMInfo structure
        const GEMMInfo gemm_info =
            GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth,
                     skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false,
                     GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, fixed_format, weight_format,
                     true /* pretranspose_B. For fp gemm (wt paths 1 - 3), we always pretranspose B (for wt path 1
                             this flag is ignored) */);

        // Perform validation step on Matrix multiply function
        return CpuGemm::validate(src, weights, biases, dst, 1.0f, 1.0f, gemm_info);
    }
}

Status CpuGemmConv2d::validate_gemm3d(const ITensorInfo         *input_info,
                                      const ITensorInfo         *weights_info,
                                      const ActivationLayerInfo &act_info,
                                      int                        gemm_3d_depth,
                                      bool                       skip_im2col)
{
    const DataType     data_type = input_info->data_type();
    const unsigned int mult_y    = skip_im2col ? 1U : gemm_3d_depth;
    const unsigned int mult_z    = skip_im2col ? gemm_3d_depth : 1U;

    // Set dummy tensor shapes for the validation
    const TensorInfo dummy_input_info(TensorShape(4U, 4U * mult_y, 1U * mult_z), 1, data_type,
                                      input_info->quantization_info());
    const TensorInfo dummy_weights_info(TensorShape(4U, 4U), 1, data_type, weights_info->quantization_info());
    const TensorInfo dummy_output_info(TensorShape(4U, 4U, gemm_3d_depth), 1, data_type,
                                       input_info->quantization_info());

    return validate_mm(&dummy_input_info, &dummy_weights_info, nullptr, &dummy_output_info, act_info, false,
                       gemm_3d_depth, skip_im2col);
}

void CpuGemmConv2d::configure(const ITensorInfo         *src,
                              const ITensorInfo         *weights,
                              const ITensorInfo         *biases,
                              ITensorInfo               *dst,
                              const PadStrideInfo       &conv_info,
                              const WeightsInfo         &weights_info,
                              const Size2D              &dilation,
                              const ActivationLayerInfo &act_info,
                              bool                       enable_fast_math,
                              unsigned int               num_groups)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst);
    ARM_COMPUTE_UNUSED(num_groups, weights_info);
    ARM_COMPUTE_ERROR_THROW_ON(CpuGemmConv2d::validate(src, weights, biases, dst, conv_info, weights_info, dilation,
                                                       act_info, enable_fast_math, num_groups));
    ARM_COMPUTE_LOG_PARAMS(src, weights, biases, dst, conv_info, weights_info, dilation, act_info, enable_fast_math,
                           num_groups);

    const DataType   data_type   = src->data_type();
    const DataLayout data_layout = src->data_layout();
    const int        idx_width   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
    const int        idx_height  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
    const int        idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
    const int        idx_kernels = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);

    const unsigned int kernel_width  = weights->dimension(idx_width);
    const unsigned int kernel_height = weights->dimension(idx_height);

    _is_prepared  = weights_info.retain_internal_weights();
    _is_quantized = is_data_type_quantized_asymmetric(src->data_type());
    _data_layout  = data_layout;
    _skip_im2col  = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 &&
                     conv_info.stride().first == 1 && conv_info.stride().second == 1);

    const ITensorInfo *gemm_input_to_use  = src;
    ITensorInfo       *gemm_output_to_use = dst;

    // Get convolved dimensions
    unsigned int conv_w = 0;
    unsigned int conv_h = 0;
    std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width), src->dimension(idx_height), kernel_width,
                                                 kernel_height, conv_info, dilation);

    ARM_COMPUTE_ERROR_ON_MSG((dst->dimension(idx_width) != conv_w) || (dst->dimension(idx_height) != conv_h),
                             "Output shape does not match the expected one");

    // Check if GEMM3D is supported
    const CpuGemmConv2d::SkipInfo skip_info =
        CpuGemmConv2d::skip_im_col_info(src, weights, conv_info, dilation, act_info);
    _skip_im2col = skip_info.skip_im2col;
    _skip_col2im = skip_info.skip_col2im;

    // Get parameters from conv_info
    unsigned int stride_x = 0;
    unsigned int stride_y = 0;
    std::tie(stride_x, stride_y) = conv_info.stride();

    // Initialize reshaped weights
    initialize_reshaped_weight_info(*weights, _weights_reshaped);

    // Create tensor to store im2col reshaped inputs
    if (!_skip_im2col)
    {
        const int    block_by        = arm_compute::block_by(weights_info.weight_format());
        unsigned int input_pad_right = 0;
        if (block_by > 1)
        {
            input_pad_right =
                (src->dimension(idx_channel) % block_by) == 0 ? 0 : block_by - (src->dimension(idx_channel) % block_by);
        }
        // Configure
        _im2col_kernel = std::make_unique<kernels::CpuIm2ColKernel>();
        _im2col_kernel->configure(src, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, false, dilation,
                                  num_groups, input_pad_right);

        // Update GEMM input
        gemm_input_to_use = &_im2col_output;
    }

    const unsigned int mat_weights_cols = weights->dimension(idx_kernels);

    // Create temporary GEMM output tensor in case we cannot skip col2im
    const DataType output_data_type = data_type == DataType::BFLOAT16 ? DataType::F32 : data_type;
    if (!_skip_col2im)
    {
        TensorShape shape_gemm;

        // Calculate GEMM output shape
        shape_gemm = _im2col_output.tensor_shape();
        shape_gemm.set(0, mat_weights_cols);
        shape_gemm.set(1, conv_w * conv_h);

        _gemm_output = TensorInfo(shape_gemm, 1, output_data_type);
        _gemm_output.set_quantization_info(dst->quantization_info()).set_data_layout(src->data_layout());
        _gemm_output_3d = TensorInfo(_gemm_output);

        // Update GEMM output
        gemm_output_to_use = &_gemm_output;
    }
    else
    {
        _gemm_output_3d = TensorInfo(*dst);
        _gemm_output_3d.set_data_type(output_data_type).set_data_layout(src->data_layout()).set_is_resizable(true);
        _gemm_output = TensorInfo(_gemm_output_3d);

        // Update GEMM output
        gemm_output_to_use = &_gemm_output_3d;
    }

    // Configure GEMM
    // In case we need to skip col2im, GEMM3D (gemm_3d_depth != 0) must be called in order to avoid reshaping the
    // output matrix
    const unsigned int gemm_3d_depth = _skip_col2im ? conv_h : 0;
    const bool         fixed_format  = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;
    /** @section note_CpuGemmConv2d_weight_use_in_configure Which weights tensor should we use to configure gemm
     *
     * A. The problem:
     * In principle, we should use the weights tensor corresponding to the weights transformation path, i.e.:
     *      - If no weight transformation (_run_wt == false): use the original weights
     *      - else:                                           use the transformed weights
     * However, in practice we have a dilemma:
     *      - We need to know _run_wt before we can configure gemm with the corresponding weights, but
     *      - _run_wt depends on isVarWeightsKernel(), which is only known after gemm is configured
     *
     * B. The decision:
     * To simplify the matter, we decide to always use the transformed weights, regardless of _run_wt
     *
     * This decision requires the following conditions:
     *      1. The underlying gemm for which isVarWeightsKernel() == true must guarantee that it:
     *          A. Ignores the flag to transpose weights (GEMMInfo::pretranspose_B)
     *          B. Uses the weights/B tensor passed to it at prepare() or run() instead of the one passed at configure()
     *      2. CpuGemmConv2d, when isVarWeightsKernel() == true, must guarantee that it:
     *          A. Passes the original weights instead of the reshaped or reinterpreted weights
     *
     * C. Future actions:
     * Condition 2 is a given, based on our implementation.
     * If condition 1 cannot hold, we must make changes to the underlying gemm to:
     *      1. Either expose isVarWeightsKernel() before gemm is configured somehow, or
     *      2. Take in an additional "original_weights" tensor info at configure
     */
    configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, enable_fast_math,
                 gemm_3d_depth, fixed_format, weights_info.weight_format());

    // Can only decide isVarWeightsKernel after gemm is configured
    _run_wt = !isVarWeightsKernel();

    if (!_skip_col2im && _data_layout == DataLayout::NCHW)
    {
        // Configure col2im
        _col2im_kernel = std::make_unique<kernels::CpuCol2ImKernel>();
        _col2im_kernel->configure(gemm_output_to_use, dst, Size2D(conv_w, conv_h));
    }
    else
    {
        // Configure reshape layer
        _reshape = std::make_unique<CpuReshape>();
        _reshape->configure(gemm_output_to_use, dst);
    }

    // Check lifetime
    _aux_mem[Im2ColOutput] =
        MemoryInfo(offset_int_vec(Im2ColOutput), MemoryLifetime::Temporary, _im2col_output.total_size());
    // Add WeightsReshaped memory requirement to workspace
    // Note that in case of WeightTransformMethod::ReinterpretThenTranspose, we do not need to allocate this memory
    // However, since we cannot determine the weight transformation method until prepare (see prepare()), we have to
    // settle for allocating more
    if (_run_wt)
    {
        // Check if GEMM transforms weights
        // If the weight is further transformed by the underlying gemm after ReshapeThenTranspose then we can free
        // WeightsReshaped in prepare
        // Otherwise WeightsReshaped is the final transformation of weights and needs to persist
        bool gemm_trans_wei = _aux_mem[GemmAsmPretransposedRHS].size > 0;
        gemm_trans_wei      = _mm_gemm != nullptr ? _aux_mem[GemmTransposed1xWRHS].size > 0 : gemm_trans_wei;
        gemm_trans_wei      = _mm_gemmlowp != nullptr ? _aux_mem[GemmLowpTransposed1xWRHS].size > 0 : gemm_trans_wei;

        _aux_mem[WeightsReshaped] = MemoryInfo(offset_int_vec(WeightsReshaped),
                                               gemm_trans_wei ? MemoryLifetime::Prepare : MemoryLifetime::Persistent,
                                               _weights_reshaped.total_size());
    }
    _aux_mem[GemmOutput] = MemoryInfo(offset_int_vec(GemmOutput), MemoryLifetime::Temporary, _gemm_output.total_size());
}

Status CpuGemmConv2d::has_opt_impl(arm_compute::WeightFormat &expected_weight_format,
                                   const ITensorInfo         *src,
                                   const ITensorInfo         *weights,
                                   const ITensorInfo         *biases,
                                   const ITensorInfo         *dst,
                                   const PadStrideInfo       &conv_info,
                                   const WeightsInfo         &weights_info,
                                   const Size2D              &dilation,
                                   const ActivationLayerInfo &act_info,
                                   const bool                 enable_fast_math)
{
    const DataLayout   data_layout   = src->data_layout();
    const int          idx_width     = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
    const int          idx_height    = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
    const unsigned int kernel_width  = weights->dimension(idx_width);
    const unsigned int kernel_height = weights->dimension(idx_height);
    unsigned int       conv_w        = 0;
    unsigned int       conv_h        = 0;
    std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width), src->dimension(idx_height), kernel_width,
                                                 kernel_height, conv_info, dilation);

    const CpuGemmConv2d::SkipInfo skip_info =
        CpuGemmConv2d::skip_im_col_info(src, weights, conv_info, dilation, act_info);

    const bool         skip_im2col   = skip_info.skip_im2col;
    const bool         skip_col2im   = skip_info.skip_col2im;
    const unsigned int gemm_3d_depth = skip_col2im ? conv_h : 0;
    const bool         fixed_format  = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;

    /** @section note_CpuGemmConv2d_weight_use_in_has_opt_impl Which weights tensor should we use for has_opt_impl
     *
     * For the pretranspose_B flag, this shares a similar problem and thus the same decision as that of
     * @ref note_CpuGemmConv2d_weight_use_in_configure
     *
     * But for the weights, we shall always use the original instead of reshaped weights here
     */
    const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth,
                                        skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false,
                                        GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info,
                                        fixed_format, weights_info.weight_format(), true /* pretranspose_B */);

    return CpuGemm::has_opt_impl(expected_weight_format, src, weights, biases, dst, gemm_info);
}

Status CpuGemmConv2d::validate(const ITensorInfo         *src,
                               const ITensorInfo         *weights,
                               const ITensorInfo         *biases,
                               const ITensorInfo         *dst,
                               const PadStrideInfo       &conv_info,
                               const WeightsInfo         &weights_info,
                               const Size2D              &dilation,
                               const ActivationLayerInfo &act_info,
                               bool                       enable_fast_math,
                               unsigned int               num_groups)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst);
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights_info.are_reshaped(), "Weights already reshaped are not supported!");
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
                                                         DataType::BFLOAT16, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
                                                         DataType::QSYMM8_PER_CHANNEL, DataType::BFLOAT16,
                                                         DataType::F16, DataType::F32);

    if (!is_fixed_format(weights_info.weight_format()))
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, weights);
    }

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups > 1, "Grouping (num_groups != 1) is not supported");

    const DataLayout data_layout = src->data_layout();
    const DataType   data_type   = src->data_type();
    const int        idx_width   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
    const int        idx_height  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
    const int        idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
    const int        idx_kernels = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);

    const unsigned int kernel_width  = weights->dimension(idx_width);
    const unsigned int kernel_height = weights->dimension(idx_height);

    TensorInfo         im2col_reshaped_info{};
    TensorInfo         info_gemm{};
    TensorInfo         tmp_info{};
    TensorInfo         weights_reshaped_info{};
    const ITensorInfo *gemm_input_to_use  = src;
    const ITensorInfo *gemm_output_to_use = dst;
    const ITensorInfo *weights_to_use     = weights;

    const bool append_bias  = false;
    const bool is_quantized = is_data_type_quantized_asymmetric(data_type);
    const bool is_bf16      = data_type == DataType::BFLOAT16;

    // Get convolved dimensions
    unsigned int conv_w = 0;
    unsigned int conv_h = 0;

    std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width), src->dimension(idx_height), kernel_width,
                                                 kernel_height, conv_info, dilation);

    // Check if GEMM3D is supported
    const CpuGemmConv2d::SkipInfo skip_info =
        CpuGemmConv2d::skip_im_col_info(src, weights, conv_info, dilation, act_info);
    const bool skip_im2col = skip_info.skip_im2col, skip_col2im = skip_info.skip_col2im;

    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_channel) != src->dimension(idx_channel));
    ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);

    // Validate biases
    if (biases != nullptr)
    {
        if (is_quantized)
        {
            ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32);
        }
        else if (is_bf16)
        {
            ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::F32);
        }
        else
        {
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, biases);
        }
        ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != dst->dimension(idx_channel));
        ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
    }

    unsigned int mat_weights_cols = weights->dimension(idx_kernels);
    unsigned int mat_weights_rows =
        weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel);

    // Initialize reshaped weights
    initialize_reshaped_weight_info(*weights, weights_reshaped_info);
    // No need to call CpuReshape::validate() or CpuTranspose::validate() as the dst info is auto-configured from the
    // src
    weights_to_use = &weights_reshaped_info;

    if (!skip_im2col)
    {
        const int block_by        = arm_compute::block_by(weights_info.weight_format());
        int       input_pad_right = 0;
        if (block_by > 1)
        {
            input_pad_right =
                (src->dimension(idx_channel) % block_by) == 0 ? 0 : block_by - (src->dimension(idx_channel) % block_by);
            mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) *
                               (weights->dimension(idx_channel) + input_pad_right);
        }

        // Create tensor info for im2col reshaped inputs
        // For CPU, the batch size is on the fourth dimension
        TensorShape shape_im2col = src->tensor_shape();
        shape_im2col.set(0, mat_weights_rows);
        shape_im2col.set(1, conv_w * conv_h);
        shape_im2col.set(2, 1);

        im2col_reshaped_info = TensorInfo(shape_im2col, 1, data_type);
        im2col_reshaped_info.set_quantization_info(src->quantization_info());
        ARM_COMPUTE_RETURN_ON_ERROR(
            kernels::CpuIm2ColKernel::validate(src, &im2col_reshaped_info, Size2D(kernel_width, kernel_height),
                                               conv_info, append_bias, dilation, num_groups, input_pad_right));
        gemm_input_to_use = &im2col_reshaped_info;
    }

    // Create temporary GEMM output tensor in case we cannot skip col2im
    const DataType output_data_type = data_type == DataType::BFLOAT16 ? DataType::F32 : data_type;
    if (!skip_col2im)
    {
        TensorShape shape_gemm = gemm_input_to_use->tensor_shape();
        shape_gemm.set(0, mat_weights_cols);
        shape_gemm.set(1, conv_w * conv_h);
        info_gemm = TensorInfo(shape_gemm, 1, output_data_type);
    }
    else
    {
        info_gemm = TensorInfo(dst->tensor_shape(), 1, output_data_type);
    }
    info_gemm.set_quantization_info(dst->quantization_info()).set_data_layout(src->data_layout());
    gemm_output_to_use      = &info_gemm;
    const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;

    // See note_CpuGemmConv2d_weight_use_in_configure regarding the choice of the weights
    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info,
                                            enable_fast_math, skip_col2im ? conv_h : 0, skip_im2col, fixed_format,
                                            weights_info.weight_format()));

    // Validate Col2Im/ReshapeLayer
    if (!skip_col2im && (data_layout == DataLayout::NCHW))
    {
        ARM_COMPUTE_RETURN_ON_ERROR(
            kernels::CpuCol2ImKernel::validate(gemm_output_to_use, dst, Size2D(conv_w, conv_h)));
    }

    return Status{};
}

void CpuGemmConv2d::run(ITensorPack &tensors)
{
    prepare(tensors);

    auto src = tensors.get_const_tensor(ACL_SRC_0);
    auto dst = tensors.get_tensor(ACL_DST);
    auto gemm_input_to_use = src;

    CpuAuxTensorHandler im2col_output(offset_int_vec(Im2ColOutput), _im2col_output, tensors, false);
    CpuAuxTensorHandler gemm_output(offset_int_vec(GemmOutput), _gemm_output, tensors, false);

    bool out_has_padding = _skip_col2im && (dst->info()->padding().bottom != 0 || dst->info()->padding().top != 0);
    if (!_skip_im2col)
    {
        // Run input reshaping
        unsigned int hint_dim            = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
        unsigned int x_dim               = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH);
        unsigned int hint_dim_iterations = _im2col_kernel->window().num_iterations(hint_dim);
        unsigned int x_dim_iterations    = _im2col_kernel->window().num_iterations(x_dim);
        if (hint_dim_iterations < NEScheduler::get().num_threads() && x_dim_iterations > hint_dim_iterations)
        {
            hint_dim = x_dim;
        }
        ITensorPack pack = {{TensorType::ACL_SRC, src}, {TensorType::ACL_DST, im2col_output.get()}};
        NEScheduler::get().schedule_op(_im2col_kernel.get(), hint_dim, _im2col_kernel->window(), pack);
        gemm_input_to_use = im2col_output.get();
    }

    // Handle the case where the output has top/bottom padding
    const ITensor *out_to_use = out_has_padding ? gemm_output.get() : dst;
    Tensor         gemm3d;
    _gemm_output_3d.extend_padding(out_to_use->info()->padding());
    gemm3d.allocator()->soft_init(_gemm_output_3d);
    gemm3d.allocator()->import_memory(out_to_use->buffer());
    auto gemm_output_to_use = gemm_output.get();

    if (_skip_im2col)
    {
        gemm_output_to_use = &gemm3d;
    }
    if (_skip_col2im && !out_has_padding)
    {
        gemm_output_to_use = dst;
    }

    ITensorPack gemm_pack = tensors;
    gemm_pack.add_const_tensor(TensorType::ACL_SRC_0, gemm_input_to_use);
    gemm_pack.add_tensor(TensorType::ACL_DST, gemm_output_to_use);
    // Allocate reshaped weights if required
    auto weights = gemm_pack.get_const_tensor(TensorType::ACL_SRC_1);
    ARM_COMPUTE_ERROR_ON_NULLPTR(weights);
    // Re-interpreted weights. Only the tensor shape is changed. Only memory import, no allocation
    const bool use_reinterpreted_wei = (_run_wt && _wt_method == WeightTransformMethod::ReinterpretThenTranspose);
    CpuAuxTensorHandler reinterpreted_wei(
        _weights_reshaped, *weights,
        /* import only if we chose the ReinterpretThenTranspose path, because otherwise the weight may have been freed */
        !use_reinterpreted_wei);

    const bool use_reshaped_wei = (_run_wt && (_wt_method == WeightTransformMethod::ReshapeThenTranspose ||
                                               _wt_method == WeightTransformMethod::FusedReshapeAndTranspose));
    CpuAuxTensorHandler reshaped_wei(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors,
                                     false /* pack_inject */, !use_reshaped_wei /* bypass_alloc */,
                                     !use_reshaped_wei /* bypass_import */
    );
    // Update the weights to use if they have been reshaped or reinterpreted
    if (use_reinterpreted_wei)
    {
        gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reinterpreted_wei.get());
    }
    else if (use_reshaped_wei)
    {
        gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
    }

    // Run the CpuGemm or CpuGemmLowpMatrixMultiplyCore function
    _is_quantized ? _mm_gemmlowp->run(gemm_pack) : _mm_gemm->run(gemm_pack);

    // Reshape the output matrix
    if (!_skip_col2im)
    {
        if (_data_layout == DataLayout::NCHW)
        {
            ITensorPack pack = {{TensorType::ACL_SRC, gemm_output.get()}, {TensorType::ACL_DST, dst}};
            NEScheduler::get().schedule_op(_col2im_kernel.get(), Window::DimY, _col2im_kernel->window(), pack);
        }
        else
        {
            ITensorPack pack = {{TensorType::ACL_SRC, gemm_output_to_use}, {TensorType::ACL_DST, dst}};
            _reshape->run(pack);
        }
    }
    else if (out_has_padding)
    {
        ITensorPack pack = {{TensorType::ACL_SRC, gemm_output_to_use}, {TensorType::ACL_DST, dst}};
        _reshape->run(pack);
    }
}

void CpuGemmConv2d::prepare(ITensorPack &tensors)
{
    if (!_is_prepared)
    {
        auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
        // Determine which weights reshape path to take
        // Note that this decision can only occur at prepare instead of configure because it relies on the presence of
        // any holes in the weight tensor, which may change after configure (e.g. from extending padding)
        if (_run_wt)
        {
            _wt_method = get_wt_method(*(weights->info()));
            switch (_wt_method)
            {
                case (WeightTransformMethod::FusedReshapeAndTranspose):
                {
                    ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Perform weight transformation: FusedReshapeAndTranspose");
                    _weights_reshape_and_transpose_kernel = std::make_unique<kernels::CpuWeightsReshapeKernel>();
                    _weights_reshape_and_transpose_kernel->configure(weights->info(), nullptr, &_weights_reshaped);
                    break;
                }
                case (WeightTransformMethod::ReshapeThenTranspose):
                {
                    ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Perform weight transformation: ReshapeThenTranspose");
                    _weights_reshape = std::make_unique<CpuReshape>();
                    _weights_reshape->configure(weights->info(), &_weights_reshaped);
                    break;
                }
                case (WeightTransformMethod::ReinterpretThenTranspose):
                {
                    ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Perform weight transformation: ReinterpretThenTranspose");
                    // Nothing to configure
                    break;
                }
                default:
                {
                    ARM_COMPUTE_ERROR("Unsupported weight transform method");
                }
            }
        }
        else
        {
            ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("No weight transformation is performed");
        }
        ITensorPack gemm_pack = tensors;
        // Allocate reshaped weights if required
        CpuAuxTensorHandler reinterpreted_wei(
            _weights_reshaped,
            *weights); // Re-interpreted weights. Only tensor shape is changed. No allocation
        CpuAuxTensorHandler reshaped_wei(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors);
        // Run weights reshape if required
        if (_run_wt)
        {
            switch (_wt_method)
            {
                case (WeightTransformMethod::FusedReshapeAndTranspose):
                {
                    ITensorPack pack = {{TensorType::ACL_SRC, weights}, {TensorType::ACL_DST, reshaped_wei.get()}};
                    NEScheduler::get().schedule_op(_weights_reshape_and_transpose_kernel.get(), Window::DimW,
                                                   _weights_reshape_and_transpose_kernel->window(), pack);
                    weights->mark_as_unused();
                    gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
                    break;
                }
                case (WeightTransformMethod::ReshapeThenTranspose):
                {
                    ITensorPack pack = {{TensorType::ACL_SRC, weights}, {TensorType::ACL_DST, reshaped_wei.get()}};
                    _weights_reshape->run(pack);
                    weights->mark_as_unused();
                    gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
                    break;
                }
                case (WeightTransformMethod::ReinterpretThenTranspose):
                {
                    gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reinterpreted_wei.get());
                    // Nothing to run
                    break;
                }
                default:
                {
                    ARM_COMPUTE_ERROR("Unsupported weight transform method");
                }
            }
        }
        _is_quantized ? _mm_gemmlowp->prepare(gemm_pack) : _mm_gemm->prepare(gemm_pack);

        _is_prepared = true;
    }
}
experimental::MemoryRequirements CpuGemmConv2d::workspace() const
{
    return _aux_mem;
}
bool CpuGemmConv2d::isVarWeightsKernel() const
{
    return _mm_gemm && _mm_gemm->isVarWeightsKernel();
}
} // namespace cpu
} // namespace arm_compute