/*
 * Copyright (c) 2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "src/cpu/operators/CpuMatMul.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/experimental/Types.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "arm_compute/runtime/NEON/functions/NEMatMul.h"
#include "src/common/utils/Log.h"
#include "src/core/CPP/Validate.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/MemoryHelpers.h"
#include "src/core/utils/quantization/AsymmHelpers.h"
#include "src/cpu/utils/CpuAuxTensorHandler.h"

using namespace arm_compute::experimental;

namespace arm_compute
{
namespace cpu
{
namespace
{
Status get_gemmlowp_output_stage_info(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const ActivationLayerInfo &act,
                                      GEMMLowpOutputStageInfo &gemmlowp_output_stage_info)
{
    const auto                    data_type = src->data_type();
    const QuantizationInfo        oq_info   = dst->quantization_info();
    const UniformQuantizationInfo iq_unif   = src->quantization_info().uniform();
    const UniformQuantizationInfo wq_unif   = weights->quantization_info().uniform();
    const UniformQuantizationInfo oq_unif   = oq_info.uniform();

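    // Fold the input and weight quantization scales with the output scale into a single effective
    // multiplier; calculate_quantized_multiplier() decomposes it into a fixed-point multiplier and
    // a right shift for the QUANTIZE_DOWN_FIXEDPOINT output stage configured below.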
    float   multiplier = (iq_unif.scale * wq_unif.scale) / oq_unif.scale;
    int32_t output_multiplier;
    int32_t output_shift;

    ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));

    int32_t type_min = 0;
    int32_t type_max = 0;
    std::tie(type_min, type_max) = quantization::get_quantized_asymmetric_output_min_max(oq_info, act, data_type);

    gemmlowp_output_stage_info.gemmlowp_multiplier = output_multiplier;
    gemmlowp_output_stage_info.gemmlowp_shift      = output_shift;
    gemmlowp_output_stage_info.gemmlowp_offset     = oq_unif.offset;
    gemmlowp_output_stage_info.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
    gemmlowp_output_stage_info.gemmlowp_min_bound  = type_min;
    gemmlowp_output_stage_info.gemmlowp_max_bound  = type_max;

    return Status{};
}
} // namespace

CpuMatMul::CpuMatMul()
    : _transpose_kernel_lhs(), _transpose_kernel_rhs(), _asm_glue(), _lhs_transposed(), _rhs_transposed(), _original_lhs_shape(), _original_rhs_shape(), _original_dst_shape()
{
}

Status CpuMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings, const ActivationLayerInfo &act_info)
{
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs, dst);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs->are_values_constant(), "LHS Tensor must be dynamic.");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs->are_values_constant(), "RHS Tensor must be dynamic.");
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(lhs);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(lhs);

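    // adj_lhs/adj_rhs ("adjoint") request that the corresponding operand is transposed before the
    // multiplication, the MatMul counterpart of GEMM's transA/transB flags.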
    const auto adj_lhs = info.adj_lhs();
    const auto adj_rhs = info.adj_rhs();

    const ITensorInfo *lhs_to_use = lhs;
    const ITensorInfo *rhs_to_use = rhs;
    TensorInfo         lhs_transposed{};
    TensorInfo         rhs_transposed{};

    auto gemm_info            = AsmGemmInfo();
    gemm_info.activation_info = act_info;
    gemm_info.fast_mode       = settings.fast_math();

    // Validate and then permute lhs/rhs where transposition is requested
    if(adj_lhs)
    {
        auto_init_if_empty(lhs_transposed, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_transposed_shape(*lhs)));
        ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuTransposeKernel::validate(lhs_to_use, &lhs_transposed));
        // Point lhs_to_use at the transposed TensorInfo
        lhs_to_use = &lhs_transposed;
    }
    if(adj_rhs)
    {
        auto_init_if_empty(rhs_transposed, rhs->clone()->set_tensor_shape(misc::shape_calculator::compute_transposed_shape(*rhs)));
        ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuTransposeKernel::validate(rhs_to_use, &rhs_transposed));
        // Point rhs_to_use at the transposed TensorInfo
        rhs_to_use = &rhs_transposed;
    }

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_to_use->dimension(0) != rhs_to_use->dimension(1),
                                    "The product AB is defined only if the number of columns in A is equal to the number of rows in B (after transpose)");

    // Iterate over the dimensions to be collapsed in the operator - check that they are equivalent between the two tensors
    for(unsigned int i = 2; i < Coordinates::num_max_dimensions; i++)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_to_use->dimension(i) != rhs_to_use->dimension(i), "Broadcasting in Batch dimension is unsupported by this operator.");
    }

    // Quantized-specific configuration
    if(is_data_type_quantized(lhs->data_type()))
    {
        ARM_COMPUTE_RETURN_ON_ERROR(get_gemmlowp_output_stage_info(lhs_to_use, rhs_to_use, dst, gemm_info.activation_info, gemm_info.output_stage));
    }

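    // Bias (c) is nullptr as MatMul does not support a bias; the assembly dispatch performs the
    // remaining check on whether an assembly kernel is available for this configuration.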
    cpu::CpuGemmAssemblyDispatch::validate(lhs_to_use, rhs_to_use, nullptr, dst, gemm_info);

    return Status{};
}

void CpuMatMul::configure(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings, const ActivationLayerInfo &act_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst);
    ARM_COMPUTE_LOG_PARAMS(lhs, rhs, dst, info, settings);
    ARM_COMPUTE_ERROR_THROW_ON(CpuMatMul::validate(lhs, rhs, dst, info, settings));

    _adj_lhs   = info.adj_lhs();
    _adj_rhs   = info.adj_rhs();
    _fast_math = settings.fast_math();

    // 1. Create and reshape tensors
    // ------------------------------------------------------
    // a. Clone TensorInfo to prevent changing original tensor values during setup
    // b. Change shape of lhs/dst to [x, y, 1, collapsed(z)] to match assembly kernel configuration
    // c. For rhs, collapse all dimensions larger than 3 into the z dimension
    TensorInfo lhs_to_use = *lhs->clone();
    TensorInfo dst_to_use = *dst->clone();
    TensorInfo rhs_to_use = *rhs->clone();

    // Save starting shape of tensors
    _original_lhs_shape = lhs_to_use.tensor_shape();
    _original_dst_shape = dst_to_use.tensor_shape();
    _original_rhs_shape = rhs_to_use.tensor_shape();

    // Reshape lhs/dst and collapse rhs for use with the assembly kernels.
    lhs_to_use.set_tensor_shape(TensorShape(_original_lhs_shape.x(), _original_lhs_shape.y(), 1, _original_lhs_shape.collapsed_from(2).z()));
    dst_to_use.set_tensor_shape(TensorShape(_original_dst_shape.x(), _original_dst_shape.y(), 1, _original_dst_shape.collapsed_from(2).z()));
    rhs_to_use.set_tensor_shape(_original_rhs_shape.collapsed_from(2));

    // 2. Configuration for transpose of lhs/rhs
    // ------------------------------------------------------
    // Initialise transposed TensorInfo class for aux tensors (intermediate tensors)
    if(_adj_lhs)
    {
        // Setup transpose LHS
        _transpose_kernel_lhs = std::make_unique<cpu::kernels::CpuTransposeKernel>();
        _transpose_kernel_lhs->configure(&lhs_to_use, &_lhs_transposed);
    }

    if(_adj_rhs)
    {
        // Setup transpose RHS
        _transpose_kernel_rhs = std::make_unique<cpu::kernels::CpuTransposeKernel>();
        _transpose_kernel_rhs->configure(&rhs_to_use, &_rhs_transposed);
    }

    // 3. Configure assembly kernel using transposed tensors.
    // ------------------------------------------------------
    // Use transposed tensors if the corresponding transpose flags are set
    // Fill AsmGemmInfo class object before configuration
    _gemm_info.activation_info = act_info;
    _gemm_info.fast_mode       = settings.fast_math();
    _gemm_info.negated_offsets = false;

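    // From here on use the transposed TensorInfos where a transpose was requested, so the assembly
    // dispatch is configured against the post-transpose shapes.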
    lhs_to_use = (_adj_lhs) ? _lhs_transposed : lhs_to_use;
    rhs_to_use = (_adj_rhs) ? _rhs_transposed : rhs_to_use;

    // Quantized-specific configuration
    if(is_data_type_quantized(lhs->data_type()))
    {
        get_gemmlowp_output_stage_info(&lhs_to_use, &rhs_to_use, &dst_to_use, _gemm_info.activation_info, _gemm_info.output_stage);
    }

    // Configure Asm Kernel
    _asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
    _asm_glue->configure(&lhs_to_use, &rhs_to_use, nullptr, &dst_to_use, _gemm_info); // c is nullptr as bias is not supported in MatMul

    // Specify memory requirements for intermediate tensors
    auto asm_mem_req = _asm_glue->workspace();
    // Specify memory required by gemm kernel
    int idx = 0;
    for(const auto &aux : asm_mem_req)
    {
        _aux_mem[idx] = aux;
        idx++;
    }
    // Memory requirements for transposed tensors
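    // Lifetime is Temporary because the transposed copies are only needed while run() executes.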
    _aux_mem[TransposeLHS] = MemoryInfo(offset_int_vec(TransposeLHS), MemoryLifetime::Temporary, lhs->total_size());
    _aux_mem[TransposeRHS] = MemoryInfo(offset_int_vec(TransposeRHS), MemoryLifetime::Temporary, rhs->total_size());
}

void CpuMatMul::run(ITensorPack &tensors)
{
    // Retrieve tensors from tensor pack
    auto lhs = tensors.get_tensor(ACL_SRC_0);
    auto rhs = tensors.get_const_tensor(ACL_SRC_1);
    auto dst = tensors.get_tensor(ACL_DST);

    // Reshape LHS and DST to ensure compatibility with the GEMM asm kernel (the batch dimension is the 4th dimension for lhs and dst within asm)
    // Collapse RHS (necessary to support dimensions larger than 3 in gemm assembly)
    lhs->info()->set_tensor_shape(TensorShape(_original_lhs_shape.x(), _original_lhs_shape.y(), 1, _original_lhs_shape.collapsed_from(2).z())); // Collapsed 3+ dimensions into z
    dst->info()->set_tensor_shape(TensorShape(_original_dst_shape.x(), _original_dst_shape.y(), 1, _original_dst_shape.collapsed_from(2).z())); // Collapsed 3+ dimensions into z
    rhs->info()->set_tensor_shape(_original_rhs_shape.collapsed_from(2));

    // Initialise objects to handle the transposed tensors stored in auxiliary memory
    CpuAuxTensorHandler lhs_transposed(offset_int_vec(TransposeLHS), _lhs_transposed, tensors, true);
    CpuAuxTensorHandler rhs_transposed(offset_int_vec(TransposeRHS), _rhs_transposed, tensors, true);
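    // The handlers bind the TransposeLHS/TransposeRHS workspace slots reserved in configure() to the
    // memory supplied through the tensor pack, exposing them as ITensors for the transpose kernels.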

    // Create tensor pack for asm kernel
    ITensorPack asm_tensors(tensors);

    // Run transpose lhs if necessary
    if(_adj_lhs)
    {
        ITensorPack lhs_transpose_pack = { { TensorType::ACL_SRC, lhs }, { TensorType::ACL_DST, lhs_transposed.get() } };
        NEScheduler::get().schedule_op(_transpose_kernel_lhs.get(), Window::DimY, _transpose_kernel_lhs->window(), lhs_transpose_pack);
        asm_tensors.add_const_tensor(TensorType::ACL_SRC_0, lhs_transposed.get());
    }
    // Run transpose rhs if necessary
    if(_adj_rhs)
    {
        ITensorPack rhs_transpose_pack = { { TensorType::ACL_SRC, rhs }, { TensorType::ACL_DST, rhs_transposed.get() } };
        NEScheduler::get().schedule_op(_transpose_kernel_rhs.get(), Window::DimY, _transpose_kernel_rhs->window(), rhs_transpose_pack);
        asm_tensors.add_const_tensor(TensorType::ACL_SRC_1, rhs_transposed.get());
    }
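    // At this point ACL_SRC_0/ACL_SRC_1 in asm_tensors reference the transposed tensors where the
    // corresponding adj flags are set; otherwise the caller's original lhs/rhs are used unchanged.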
    // Run asm kernel
    _asm_glue->run(asm_tensors);

    // Undo the reshapes: the ITensorInfos were modified in place above, so restore the caller-visible shapes
    dst->info()->set_tensor_shape(_original_dst_shape);
    lhs->info()->set_tensor_shape(_original_lhs_shape);
    rhs->info()->set_tensor_shape(_original_rhs_shape);
}

experimental::MemoryRequirements CpuMatMul::workspace() const
{
    return _aux_mem;
}
} // namespace cpu
} // namespace arm_compute