/*
 * Copyright (c) 2021-2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/cpu/operators/CpuGemm.h"

#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"

#include "src/common/utils/Log.h"
#include "src/core/CPP/Validate.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/MemoryHelpers.h"
#include "src/cpu/utils/CpuAuxTensorHandler.h"

using namespace arm_compute::experimental;
using namespace arm_compute::misc::shape_calculator;

namespace arm_compute
{
namespace cpu
{
namespace
{
cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
{
    cpu::AsmGemmInfo asm_info;
    asm_info.method                  = cpu::AsmConvMethod::Im2Col;
    asm_info.reinterpret_input_as_3d = info.reinterpret_input_as_3d();
    asm_info.depth_output_gemm3d     = info.depth_output_gemm3d();
    asm_info.activation_info         = info.activation_info();
    asm_info.fast_mode               = info.fast_math();
    asm_info.fixed_format            = info.fixed_format();
    asm_info.weight_format           = info.weight_format();
    // Note: the "pretranspose_B" flag is not the same as the pretranspose_B_array method. The flag signals to
    // pretranspose_B_array whether an additional transpose of B should be performed before that method runs.
    asm_info.transpose_b = info.pretranspose_B();

    return asm_info;
}
} // namespace

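// CpuGemm computes d = alpha * (a * b) + beta * c, either through a single optimized assembly kernel
// (CpuGemmAssemblyDispatch) or through a fallback sequence of kernels (interleave/transpose reshapes, matrix
// multiply, and optional bias addition, matrix addition and activation stages), as configured below.
//
// Minimal usage sketch (illustrative only: tensor creation, allocation and auxiliary workspace handling are
// omitted, and a/b/d and a_info/b_info/d_info are placeholder names, not part of this file):
//
//   CpuGemm gemm;
//   ARM_COMPUTE_ERROR_THROW_ON(CpuGemm::validate(&a_info, &b_info, nullptr, &d_info, 1.f, 0.f, GEMMInfo()));
//   gemm.configure(&a_info, &b_info, nullptr, &d_info, 1.f, 0.f, GEMMInfo());
//   ITensorPack pack{{ACL_SRC_0, &a}, {ACL_SRC_1, &b}, {ACL_DST, &d}};
//   gemm.run(pack); // auxiliary tensors reported by workspace() are normally added to the pack by the caller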
void CpuGemm::configure(const ITensorInfo *a,
                        const ITensorInfo *b,
                        const ITensorInfo *c,
                        ITensorInfo       *d,
                        float              alpha,
                        float              beta,
                        const GEMMInfo    &gemm_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
    ARM_COMPUTE_ERROR_THROW_ON(CpuGemm::validate(a, b, c, d, alpha, beta, gemm_info));
    ARM_COMPUTE_LOG_PARAMS(a, b, c, d, alpha, beta, gemm_info);

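    // How the addend c is handled depends on beta:
    //   beta == 0      : c is ignored
    //   beta == 1      : c is treated as a bias, added via CpuAdd (or passed to the assembly kernel directly)
    //   any other beta : c is added through the separate matrix addition kernel, which is not supported by the
    //                    optimized assembly path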
    const cpu::AsmGemmInfo asm_info      = init_assembly_metadata(gemm_info);
    const bool             is_c_bias     = beta == 1 && c != nullptr;
    const bool             run_optimised =
        bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, (is_c_bias) ? c : nullptr, d, asm_info)) &&
        (c == nullptr || beta == 0.f || beta == 1.f) && // Optimized GeMM doesn't support beta coefficient.
        !(!b->are_values_constant() &&
          b->tensor_shape().z() > 1); // Disable batch matmul as optimized GeMM handles batching differently.

    // Check if we need to reshape the matrix B only on the first run
    _is_prepared                      = false;
    _reshape_b_only_on_first_run      = b->are_values_constant();
    _run_vector_matrix_multiplication = a->dimension(1) < 2;
    _run_alpha_scale                  = alpha != 1.f;
    _run_bias_addition                = is_c_bias;
    _run_addition                     = beta != 0 && beta != 1 && c != nullptr;
    _run_activation =
        gemm_info.activation_info().enabled() &&
        (!run_optimised ||
         (run_optimised && !cpu::CpuGemmAssemblyDispatch::is_activation_supported(gemm_info.activation_info())));

    if (run_optimised)
    {
        _run_interleave_transpose   = false;
        const ITensorInfo *c_to_use = is_c_bias ? c : nullptr;
        _asm_glue                   = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
        _asm_glue->configure(a, b, c_to_use, d, asm_info);
        ARM_COMPUTE_ERROR_ON(!_asm_glue->is_configured());

        const auto asm_mem_req = _asm_glue->workspace();
        for (unsigned int slot = 0; slot < asm_mem_req.size(); ++slot)
        {
            _aux_mem[slot] = asm_mem_req[slot];
        }

        // Scale product by alpha
        if (_run_alpha_scale)
        {
            _alpha_scale_func = std::make_unique<cpu::CpuActivation>();
            _alpha_scale_func->configure(
                d, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LINEAR, alpha, 0.f));
        }
    }
    else
    {
        _run_interleave_transpose = !_run_vector_matrix_multiplication;
        // Pick output tensor in case bias addition should be performed
        ITensorInfo *gemm_output_to_use = (_run_bias_addition) ? &_tmp_d : d;
        // Pick b tensor in case pretranspose should be performed
        const ITensorInfo *b_to_use = b;

        _mm_kernel = std::make_unique<cpu::kernels::CpuGemmMatrixMultiplyKernel>();

        // Configure rhs pretranspose
        if (gemm_info.pretranspose_B())
        {
            _pretranspose_b_func = std::make_unique<CpuTranspose>();
            _pretranspose_b_func->configure(b_to_use, &_pretransposed_b);
            MemoryLifetime lifetime;
            if (_reshape_b_only_on_first_run)
            {
                if (_run_interleave_transpose)
                {
                    // PreTransposedRHS tensor is only used in prepare(), but is then succeeded by Transposed1xWRHS
                    // So PreTransposedRHS can be freed inside prepare()
                    lifetime = MemoryLifetime::Prepare;
                }
                else
                {
                    // PreTransposedRHS tensor is only used in prepare(), but is the final transformation of rhs
                    // So PreTransposedRHS needs to persist beyond prepare()
                    lifetime = MemoryLifetime::Persistent;
                }
            }
            else
            {
                // PreTransposedRHS tensor is always used in run() and doesn't need to persist
                lifetime = MemoryLifetime::Temporary;
            }
            _aux_mem[PreTransposedRHS] =
                MemoryInfo(offset_int_vec(PreTransposedRHS), lifetime, _pretransposed_b.total_size());
            b_to_use = &_pretransposed_b;
        }

        // Select between GEMV and GEMM
        if (_run_vector_matrix_multiplication)
        {
            // Configure the matrix multiply kernel
            _mm_kernel->configure(a, b_to_use, gemm_output_to_use, alpha, false);
        }
        else
        {
            ARM_COMPUTE_ERROR_ON(!_run_interleave_transpose);
            // Configure interleave kernel
            _interleave_kernel = std::make_unique<cpu::kernels::CpuGemmInterleave4x4Kernel>();
            _interleave_kernel->configure(a, &_tmp_a);
            _aux_mem[InterleavedLHS] =
                MemoryInfo(offset_int_vec(InterleavedLHS), MemoryLifetime::Temporary, _tmp_a.total_size());

            // Configure rhs transpose1xw kernel
            _transpose1xW_b_kernel = std::make_unique<cpu::kernels::CpuGemmTranspose1xWKernel>();
            _transpose1xW_b_kernel->configure(b_to_use, &_tmp_b);
            _aux_mem[Transposed1xWRHS] =
                MemoryInfo(offset_int_vec(Transposed1xWRHS), MemoryLifetime::Persistent, _tmp_b.total_size());

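            // From here on A is consumed through its 4x4-interleaved copy (_tmp_a) and B through its
            // transposed 1xW copy (_tmp_b). As a reminder of the shape convention used below (ACL stores the
            // innermost dimension first): an M x K matrix A has a->dimension(0) == K and a->dimension(1) == M,
            // while a K x N matrix B has b->dimension(0) == N.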
            // Use a and b here instead of _tmp_a and _tmp_b because CpuGemmMatrixMultiplyKernel requires the original m,n,k in case of interleaved a and transposed1xw b
            const int m = a->dimension(1);
            const int n = b_to_use->dimension(0);
            const int k = a->dimension(0);

            // Configure matrix multiplication kernel
            _mm_kernel->configure(&_tmp_a, &_tmp_b, gemm_output_to_use, alpha, _run_interleave_transpose,
                                  GEMMReshapeInfo(m, n, k));
        }

        if (_run_bias_addition)
        {
            _add_bias = std::make_unique<cpu::CpuAdd>();
            _add_bias->configure(gemm_output_to_use, c, d, ConvertPolicy::SATURATE);
            _aux_mem[TempResult] =
                MemoryInfo(offset_int_vec(TempResult), MemoryLifetime::Temporary, _tmp_d.total_size());
        }
    }

    // Configure matrix addition kernel
    if (_run_addition)
    {
        _ma_kernel = std::make_unique<cpu::kernels::CpuGemmMatrixAdditionKernel>();
        _ma_kernel->configure(c, d, beta);
    }

    // Configure activation
    if (_run_activation)
    {
        _activation_func = std::make_unique<cpu::CpuActivation>();
        _activation_func->configure(d, nullptr, gemm_info.activation_info());
    }
}

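// validate() mirrors configure(): it performs the shape and data-type checks up front and then validates either
// the optimized assembly path or each of the fallback kernels that configure() would instantiate.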
Status CpuGemm::validate(const ITensorInfo *a,
                         const ITensorInfo *b,
                         const ITensorInfo *c,
                         const ITensorInfo *d,
                         float              alpha,
                         float              beta,
                         const GEMMInfo    &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    const bool is_c_bias    = beta == 1 && c != nullptr;
    const bool run_addition = c != nullptr && beta != 0 && beta != 1;
    // Check whether we should use the pretransposed b or the original b.
    // TODO: COMPMID-6597
    // This check should only apply to the non-optimized path, but it is done up front because the checks below
    // (between here and the run_optimised decision) depend on b_to_use. We should simplify this by:
    // 1. Moving the checks between "fix-start" and "fix-end" into their corresponding ops / kernels (e.g. the
    //    weight format checks can and should be moved into CpuGemmAssemblyDispatch).
    // 2. Moving this b_to_use check back into the non-optimized path.
    TensorInfo pretransposed_b = b->clone()->set_tensor_shape(misc::shape_calculator::compute_transposed_shape(*b));
    const ITensorInfo *b_to_use = gemm_info.pretranspose_B() ? &pretransposed_b : b;
    // TODO: COMPMID-6597 fix-start

    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::BFLOAT16, DataType::F16, DataType::F32);

    if (is_fixed_format_fast_math(gemm_info.weight_format()))
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b_to_use, DataType::BFLOAT16);
    }
    else
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b_to_use);
    }

    const int block_by = arm_compute::block_by(gemm_info.weight_format());
    // Test if im2col has changed the dimensions that are needed for padding
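    // Illustrative example (values chosen for exposition only): a 3x3 kernel over 3 input channels with
    // block_by = 4 pads each kernel position to 4 channels, so a->dimension(0) = 9 * (3 + 1) = 36 while
    // b_to_use->dimension(1) = 9 * 3 = 27; then input_pad_right = (36 - 27) % 4 = 1,
    // kernel_area = (36 - 27) / 1 = 9, and 36 - 9 * 1 == 27 satisfies the check below.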
    if (a->dimension(0) != b_to_use->dimension(1) && block_by > 1)
    {
        // have to verify bias
        const size_t dim0_sz = a->dimension(0);
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(
            (dim0_sz % block_by) != 0,
            ("The matrix A number of columns must be a multiple of block_by=" + std::to_string(block_by)).c_str());
        // a->dimension(0) = kernel_area * input_channel + kernel_area * input_pad_right
        // b_to_use->dimension(1) = kernel_area * input_channel
        // a->dimension(0) = b_to_use->dimension(1) + kernel_area * input_pad_right
        const size_t input_pad_right = (dim0_sz - b_to_use->dimension(1)) % block_by;
        const size_t kernel_area     = (dim0_sz - b_to_use->dimension(1)) / input_pad_right;
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(
            (dim0_sz - kernel_area * input_pad_right) != b_to_use->dimension(1),
            "The product AB is defined only if A number of columns and B number of rows are related");
    }
    else
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(
            a->dimension(0) != b_to_use->dimension(1),
            "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
    }
273
Michele Di Giorgio4dfc5532021-06-30 12:05:34 +0100274 ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
275 ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100276 if (a->data_type() != DataType::BFLOAT16)
Michele Di Giorgio4dfc5532021-06-30 12:05:34 +0100277 {
278 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, d);
279 }
280
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100281 if (run_addition)
Michele Di Giorgio4dfc5532021-06-30 12:05:34 +0100282 {
283 ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.depth_output_gemm3d() != 0);
284 ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.reinterpret_input_as_3d());
285 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(c, d);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100286 ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != c->dimension(1),
287 "The C matrix must have the same number of rows as the matrix A");
SiCong Lic5ab4df2023-10-17 17:38:57 +0100288 ARM_COMPUTE_RETURN_ERROR_ON_MSG(b_to_use->dimension(0) != c->dimension(0),
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100289 "The C matrix must have the same number of columns as the matrix B");
Michele Di Giorgio4dfc5532021-06-30 12:05:34 +0100290 }
291
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100292 if (d->total_size() != 0)
Michele Di Giorgio4dfc5532021-06-30 12:05:34 +0100293 {
Francesco Petrogalli553f6952022-06-30 10:22:01 +0000294 // For fixed format we are expecting some kind of blocked format for B/RHS so the dimension won't necessarily match the result matrix any more.
SiCong Lic5ab4df2023-10-17 17:38:57 +0100295 ARM_COMPUTE_RETURN_ERROR_ON(!gemm_info.fixed_format() && b_to_use->dimension(0) != d->dimension(0));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100296 if (gemm_info.depth_output_gemm3d() != 0)
Michele Di Giorgio4dfc5532021-06-30 12:05:34 +0100297 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100298 if (gemm_info.reinterpret_input_as_3d())
Michele Di Giorgio4dfc5532021-06-30 12:05:34 +0100299 {
300 ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != d->dimension(1));
301 ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(2) != d->dimension(2));
302 }
303 else
304 {
305 ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != d->dimension(1) * d->dimension(2));
306 }
307 }
308 else
309 {
310 ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != d->dimension(1));
311 }
312 }
SiCong Lic5ab4df2023-10-17 17:38:57 +0100313 // TODO: COMPMID-6597 fix-end

    // Check if we need to run the optimized assembly kernel
    cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);

    // Note we use b instead of b_to_use here because asm_info also captures the pretranspose_b() flag
    // so we pass the original b to CpuGemmAssemblyDispatch
    const bool run_optimised =
        bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, is_c_bias ? c : nullptr, d, asm_info)) &&
        (c == nullptr || beta == 0.f || beta == 1.f) && // Optimized GeMM doesn't support beta coefficient.
        !(!b->are_values_constant() &&
          b->tensor_shape().z() > 1); // Disable batch matmul as optimized GeMM handles batching differently.

    if (!run_optimised)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.reinterpret_input_as_3d(),
                                        "CpuGemm cannot reinterpret the input tensor as 3D");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.depth_output_gemm3d() != 0,
                                        "CpuGemm cannot reinterpret the output tensor as 3D");

        // Check if the first input tensor is a vector.
        const bool run_vector_matrix_multiplication = a->dimension(1) < 2;
        // Check if we need to reshape the matrix A and matrix B
        const bool run_interleave_transpose = !run_vector_matrix_multiplication;

        // Arguments used by GEMMReshapeInfo
        // If we pass the matrix A and matrix B reshaped to CpuGemmMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to GEMMReshapeInfo
        // in order to know how the matrices have been reshaped
        const int m                         = a->dimension(1);
        const int n                         = b_to_use->dimension(0);
        const int k                         = a->dimension(0);
        int       mult_transpose1xW_width   = 1;
        int       mult_interleave4x4_height = 1;

        const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(
            m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d());

        const ITensorInfo *matrix_a_info = a;
        const ITensorInfo *matrix_b_info = b_to_use;

        TensorInfo tmp_a_info{};
        TensorInfo tmp_b_info{};
        TensorInfo tmp_output_info = *d->clone();

        if (run_interleave_transpose)
        {
            matrix_a_info = &tmp_a_info;
            matrix_b_info = &tmp_b_info;

            // Validate interleave kernel
            auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(
                                               *a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d())));
            ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmInterleave4x4Kernel::validate(a, &tmp_a_info));

            // Validate transpose kernel
            auto_init_if_empty(tmp_b_info,
                               b_to_use->clone()->set_tensor_shape(
                                   compute_transpose1xW_with_element_size_shape(*b_to_use, mult_transpose1xW_width)));
            ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmTranspose1xWKernel::validate(b_to_use, &tmp_b_info));
        }

        // Validate matrix multiply
        auto_init_if_empty(tmp_output_info,
                           matrix_a_info->clone()->set_tensor_shape(compute_mm_shape(
                               *matrix_a_info, *matrix_b_info, run_interleave_transpose, reshape_info)));
        ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmMatrixMultiplyKernel::validate(
            matrix_a_info, matrix_b_info, &tmp_output_info, alpha, run_interleave_transpose, reshape_info));

        if (is_c_bias)
        {
            ARM_COMPUTE_RETURN_ON_ERROR(cpu::CpuAdd::validate(&tmp_output_info, c, d, ConvertPolicy::SATURATE));
        }
    }

    // Validate matrix addition kernel
    if (run_addition)
    {
        ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmMatrixAdditionKernel::validate(c, d, beta));
    }

    // Validate activation
    const ActivationLayerInfo &activation = gemm_info.activation_info();
    if (activation.enabled())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(cpu::CpuActivation::validate(d, nullptr, activation));
    }

    return Status{};
}

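// run() first triggers prepare() (one-shot reshapes of a constant B), then executes either the assembly path
// (with an optional alpha-scaling activation) or the fallback kernel sequence, followed by the optional matrix
// addition and activation stages.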
void CpuGemm::run(ITensorPack &tensors)
{
    prepare(tensors);

    auto a = tensors.get_const_tensor(ACL_SRC_0);
    auto b = tensors.get_const_tensor(ACL_SRC_1);
    auto c = tensors.get_const_tensor(ACL_SRC_2);
    auto d = tensors.get_tensor(ACL_DST);

    if (_asm_glue && _asm_glue->is_configured())
    {
        // Pass c to asm dispatch only if it's the bias tensor
        ITensorPack asm_pack = tensors;
        asm_pack.add_const_tensor(ACL_SRC_2, _run_bias_addition ? c : nullptr);
        _asm_glue->run(asm_pack);
        if (_run_alpha_scale)
        {
            ITensorPack pack{{ACL_SRC, d}, {ACL_DST, d}};
            _alpha_scale_func->run(pack);
        }
    }
    else
    {
        CpuAuxTensorHandler interleaved_a(offset_int_vec(InterleavedLHS), _tmp_a, tensors, true);
        CpuAuxTensorHandler pretransposed_b(offset_int_vec(PreTransposedRHS), _pretransposed_b, tensors);
        CpuAuxTensorHandler transposed1xw_b(offset_int_vec(Transposed1xWRHS), _tmp_b, tensors, true);
        CpuAuxTensorHandler temp_d(offset_int_vec(TempResult), _tmp_d, tensors, true);

        ITensorPack mm_pack{{ACL_SRC_0, a}, {ACL_SRC_1, b}, {ACL_DST, (_run_bias_addition) ? temp_d.get() : d}};

        if (_run_interleave_transpose)
        {
            // Run interleave kernel
            ITensorPack interleave_pack{{ACL_SRC, a}, {ACL_DST, interleaved_a.get()}};
            NEScheduler::get().schedule_op(_interleave_kernel.get(), Window::DimY, _interleave_kernel->window(),
                                           interleave_pack);
            // Use reshaped matrices
            mm_pack.add_const_tensor(ACL_SRC_0, interleaved_a.get());
        }

        const ITensor *b_to_use = b;
        if (_pretranspose_b_func)
        {
            if (!_reshape_b_only_on_first_run)
            {
                // Run pretranspose kernel
                ITensorPack pretranspose_pack{{ACL_SRC, b_to_use}, {ACL_DST, pretransposed_b.get()}};
                _pretranspose_b_func->run(pretranspose_pack);
            }
            b_to_use = pretransposed_b.get();
        }
        if (_run_interleave_transpose)
        {
            if (!_reshape_b_only_on_first_run)
            {
                // Run transpose1xw kernel
                ITensorPack transpose_pack{{ACL_SRC, b_to_use}, {ACL_DST, transposed1xw_b.get()}};
                NEScheduler::get().schedule_op(_transpose1xW_b_kernel.get(), Window::DimY,
                                               _transpose1xW_b_kernel->window(), transpose_pack);
            }
            b_to_use = transposed1xw_b.get();
        }
        // Use reshaped matrices
        mm_pack.add_const_tensor(ACL_SRC_1, b_to_use);

        NEScheduler::get().schedule_op(_mm_kernel.get(),
                                       _run_vector_matrix_multiplication ? Window::DimX : Window::DimY,
                                       _mm_kernel->window(), mm_pack);

        // Run bias addition kernel
        if (_run_bias_addition)
        {
            ITensorPack pack{{ACL_SRC_0, temp_d.get()}, {ACL_SRC_1, c}, {ACL_DST, d}};
            _add_bias->run(pack);
        }
    }

    // Run matrix addition kernel
    if (_run_addition)
    {
        ITensorPack c_add_pack{{ACL_SRC, c}, {ACL_DST, d}};
        NEScheduler::get().schedule_op(_ma_kernel.get(), Window::DimY, _ma_kernel->window(), c_add_pack);
    }

    // Run activation function
    if (_run_activation)
    {
        ITensorPack pack{{ACL_SRC, d}, {ACL_DST, d}};
        _activation_func->run(pack);
    }
}

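// prepare() performs the one-shot transformations of B (pretranspose and/or transpose1xW) when B is constant,
// so that subsequent run() calls can reuse the reshaped data.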
void CpuGemm::prepare(ITensorPack &tensors)
{
    if (!_is_prepared)
    {
        if (_asm_glue && _asm_glue->is_configured())
        {
            _asm_glue->prepare(tensors);
        }
        else if (_reshape_b_only_on_first_run)
        {
            const ITensor      *b        = tensors.get_const_tensor(ACL_SRC_1);
            const ITensor      *b_to_use = b;
            CpuAuxTensorHandler pretransposed_b(
                offset_int_vec(PreTransposedRHS), _pretransposed_b, tensors,
                false /*pack_inject: no need to inject into tensors*/,
                _pretranspose_b_func ==
                    nullptr /*bypass_alloc: no need to allocate if _pretranspose_b_func is not run*/);
            CpuAuxTensorHandler transposed1xw_b(offset_int_vec(Transposed1xWRHS), _tmp_b, tensors,
                                                false /*pack_inject*/, !_run_interleave_transpose /*bypass_alloc*/);

            if (_pretranspose_b_func)
            {
                // Run pretranspose kernel
                ITensorPack pretranspose_pack{{ACL_SRC, b_to_use}, {ACL_DST, pretransposed_b.get()}};
                _pretranspose_b_func->run(pretranspose_pack);
                b_to_use = pretransposed_b.get();
            }
            if (_run_interleave_transpose)
            {
                // Run transpose kernel
                ITensorPack transpose_pack{{ACL_SRC, b_to_use}, {ACL_DST, transposed1xw_b.get()}};
                NEScheduler::get().schedule_op(_transpose1xW_b_kernel.get(), Window::DimY,
                                               _transpose1xW_b_kernel->window(), transpose_pack);
            }
        }
        _is_prepared = true;
    }
}

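// workspace() reports the auxiliary tensors (and their lifetimes) required by the selected path; they are
// typically provided to run()/prepare() by the caller through the ITensorPack.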
experimental::MemoryRequirements CpuGemm::workspace() const
{
    return _aux_mem;
}

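// has_opt_impl() queries the assembly dispatch for an optimized implementation of this GEMM and, if one exists,
// reports the weight format it expects (relevant for fixed-format kernels).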
Status CpuGemm::has_opt_impl(arm_compute::WeightFormat &expected_weight_format,
                             const ITensorInfo         *a,
                             const ITensorInfo         *b,
                             const ITensorInfo         *c,
                             const ITensorInfo         *d,
                             const GEMMInfo            &gemm_info)
{
    const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);

    return CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, asm_info);
}

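// isVarWeightsKernel() reports whether the assembly path was configured with a variable weight-format
// (fixed-format) kernel.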
bool CpuGemm::isVarWeightsKernel() const
{
    return _asm_glue && _asm_glue->isVarWeightsKernel();
}
} // namespace cpu
} // namespace arm_compute