/*
 * Copyright (c) 2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "src/cpu/operators/CpuMatMul.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/experimental/Types.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/function_info/MatMulInfo.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "arm_compute/runtime/NEON/functions/NEMatMul.h"
#include "src/common/utils/Log.h"
#include "src/core/CPP/Validate.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/MemoryHelpers.h"
#include "src/core/utils/quantization/AsymmHelpers.h"
#include "src/cpu/utils/CpuAuxTensorHandler.h"

using namespace arm_compute::experimental;

namespace arm_compute
{
namespace cpu
{
namespace
{
Status get_gemmlowp_output_stage_info(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const ActivationLayerInfo &act,
                                      GEMMLowpOutputStageInfo &gemmlowp_output_stage_info)
{
    const auto                    data_type = src->data_type();
    const QuantizationInfo        oq_info   = dst->quantization_info();
    const UniformQuantizationInfo iq_unif   = src->quantization_info().uniform();
    const UniformQuantizationInfo wq_unif   = weights->quantization_info().uniform();
    const UniformQuantizationInfo oq_unif   = oq_info.uniform();

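    // The effective requantization scale folds the input and weight scales into the output scale:
    //   real_multiplier = (input_scale * weight_scale) / output_scale
    // calculate_quantized_multiplier() decomposes this real value into the integer multiplier and
    // shift consumed by the fixed-point output stage configured below.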
    float   multiplier = (iq_unif.scale * wq_unif.scale) / oq_unif.scale;
    int32_t output_multiplier;
    int32_t output_shift;

    ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));

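    // Clamp the requantized result to the representable range of the destination data type,
    // narrowed further when a fused activation (e.g. a bounded ReLU) is supplied.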
    int32_t type_min = 0;
    int32_t type_max = 0;
    std::tie(type_min, type_max) = quantization::get_quantized_asymmetric_output_min_max(oq_info, act, data_type);

    gemmlowp_output_stage_info.gemmlowp_multiplier = output_multiplier;
    gemmlowp_output_stage_info.gemmlowp_shift      = output_shift;
    gemmlowp_output_stage_info.gemmlowp_offset     = oq_unif.offset;
    gemmlowp_output_stage_info.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
    gemmlowp_output_stage_info.gemmlowp_min_bound  = type_min;
    gemmlowp_output_stage_info.gemmlowp_max_bound  = type_max;

    return Status{};
}
} // namespace

CpuMatMul::CpuMatMul()
    : _transpose_kernel_lhs(), _transpose_kernel_rhs(), _asm_glue(), _lhs_transposed(), _rhs_transposed(), _original_lhs_shape(), _original_rhs_shape(), _original_dst_shape()
{
}

Status CpuMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings, const ActivationLayerInfo &act_info)
{
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs, dst);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs->are_values_constant(), "LHS Tensor must be dynamic.");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs->are_values_constant(), "RHS Tensor must be dynamic.");
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(lhs);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(lhs);

    const auto adj_lhs = info.adj_lhs();
    const auto adj_rhs = info.adj_rhs();

    const ITensorInfo *lhs_to_use = lhs;
    const ITensorInfo *rhs_to_use = rhs;
    TensorInfo         lhs_transposed{};
    TensorInfo         rhs_transposed{};

    auto gemm_info            = AsmGemmInfo();
    gemm_info.activation_info = act_info;
    gemm_info.fast_mode       = settings.fast_math();

    // Validate and then permute a/b
    if(adj_lhs)
    {
        auto_init_if_empty(lhs_transposed, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_transposed_shape(*lhs)));
        ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuTransposeKernel::validate(lhs_to_use, &lhs_transposed));
        // Assign lhs_to_use pointer to use transposed TensorInfo
        lhs_to_use = &lhs_transposed;
    }
    if(adj_rhs)
    {
        auto_init_if_empty(rhs_transposed, rhs->clone()->set_tensor_shape(misc::shape_calculator::compute_transposed_shape(*rhs)));
        ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuTransposeKernel::validate(rhs_to_use, &rhs_transposed));
        // Assign rhs_to_use pointer to use transposed TensorInfo
        rhs_to_use = &rhs_transposed;
    }

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_to_use->dimension(0) != rhs_to_use->dimension(1),
                                    "The product AB is defined only if the number of columns in A is equal to the number of rows in B (after transpose)");

    // Iterate over dimensions to be collapsed in operator - check dimensions are equivalent between tensors
    for(unsigned int i = 2; i < Coordinates::num_max_dimensions; i++)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_to_use->dimension(i) != rhs_to_use->dimension(i), "Broadcasting in Batch dimension is unsupported by this operator.");
    }

    // Quantized-specific configuration
    if(is_data_type_quantized(lhs->data_type()))
    {
        ARM_COMPUTE_RETURN_ON_ERROR(get_gemmlowp_output_stage_info(lhs_to_use, rhs_to_use, dst, gemm_info.activation_info, gemm_info.output_stage));
    }

    ARM_COMPUTE_RETURN_ON_ERROR(cpu::CpuGemmAssemblyDispatch::validate(lhs_to_use, rhs_to_use, nullptr, dst, gemm_info));

    return Status{};
}

void CpuMatMul::configure(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings, const ActivationLayerInfo &act_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst);
    ARM_COMPUTE_LOG_PARAMS(lhs, rhs, dst, info, settings, act_info);
    ARM_COMPUTE_ERROR_THROW_ON(CpuMatMul::validate(lhs, rhs, dst, info, settings, act_info));

    _adj_lhs   = info.adj_lhs();
    _adj_rhs   = info.adj_rhs();
    _fast_math = settings.fast_math();

    // 1. Create and reshape tensors
    // ------------------------------------------------------
    // a. Clone TensorInfo to prevent changing original tensor values during setup
    // b. Change shape of lhs/dst to [x, y, 1, collapsed(z)] to match assembly kernel configuration
    // c. For rhs collapse all dimensions larger than 3 to z dimension
    TensorInfo lhs_to_use = *lhs->clone();
    TensorInfo dst_to_use = *dst->clone();
    TensorInfo rhs_to_use = *rhs->clone();

    // Save starting shape of tensors
    _original_lhs_shape = lhs_to_use.tensor_shape();
    _original_dst_shape = dst_to_use.tensor_shape();
    _original_rhs_shape = rhs_to_use.tensor_shape();

    // Reshape lhs/dst for use with assembly kernels and collapse rhs batch dimensions.
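    // For example, an lhs of shape [x, y, b0, b1] becomes [x, y, 1, b0 * b1] (and dst likewise),
    // while an rhs of shape [x, y, b0, b1] becomes [x, y, b0 * b1], since collapsed_from(2) folds
    // every dimension from index 2 upwards into the z dimension.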
    lhs_to_use.set_tensor_shape(TensorShape(_original_lhs_shape.x(), _original_lhs_shape.y(), 1, _original_lhs_shape.collapsed_from(2).z()));
    dst_to_use.set_tensor_shape(TensorShape(_original_dst_shape.x(), _original_dst_shape.y(), 1, _original_dst_shape.collapsed_from(2).z()));
    rhs_to_use.set_tensor_shape(_original_rhs_shape.collapsed_from(2));

    // 2. Configuration for transpose of lhs/rhs
    // ------------------------------------------------------
    // Initialise transposed TensorInfo class for aux tensors (intermediary tensors)
    if(_adj_lhs)
    {
        // Setup transpose LHS
        _transpose_kernel_lhs = std::make_unique<cpu::kernels::CpuTransposeKernel>();
        _transpose_kernel_lhs->configure(&lhs_to_use, &_lhs_transposed);
    }

    if(_adj_rhs)
    {
        // Setup transpose RHS
        _transpose_kernel_rhs = std::make_unique<cpu::kernels::CpuTransposeKernel>();
        _transpose_kernel_rhs->configure(&rhs_to_use, &_rhs_transposed);
    }

    // 3. Configure assembly kernel using transposed tensors.
    // -----------------------------------------------------
    // Use transposed tensors if the corresponding transpose flags are set
    // Fill AsmGemmInfo class object before configuration
    _gemm_info.activation_info = act_info;
    _gemm_info.fast_mode       = settings.fast_math();
    _gemm_info.negated_offsets = false;

    lhs_to_use = (_adj_lhs) ? _lhs_transposed : lhs_to_use;
    rhs_to_use = (_adj_rhs) ? _rhs_transposed : rhs_to_use;

    // Quantized-specific configuration
    if(is_data_type_quantized(lhs->data_type()))
    {
        get_gemmlowp_output_stage_info(&lhs_to_use, &rhs_to_use, &dst_to_use, _gemm_info.activation_info, _gemm_info.output_stage);
    }

    // Configure Asm Kernel
    _asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
    _asm_glue->configure(&lhs_to_use, &rhs_to_use, nullptr, &dst_to_use, _gemm_info); // c is nullptr as bias not supported in MatMul

    // Specify memory requirements for intermediate tensors
    auto asm_mem_req = _asm_glue->workspace();
    // Specify memory required by gemm kernel
    int idx = 0;
    for(const auto &aux : asm_mem_req)
    {
        _aux_mem[idx] = aux;
        idx++;
    }
    // Memory requirements for transposed tensors
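    // The transpose scratch buffers only live for the duration of run() (MemoryLifetime::Temporary)
    // and are sized to hold a full copy of the corresponding untransposed input.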
    _aux_mem[TransposeLHS] = MemoryInfo(offset_int_vec(TransposeLHS), MemoryLifetime::Temporary, lhs->total_size());
    _aux_mem[TransposeRHS] = MemoryInfo(offset_int_vec(TransposeRHS), MemoryLifetime::Temporary, rhs->total_size());
}

void CpuMatMul::run(ITensorPack &tensors)
{
    // Retrieve tensors from tensor pack
    auto lhs = tensors.get_tensor(ACL_SRC_0);
    auto rhs = tensors.get_const_tensor(ACL_SRC_1);
    auto dst = tensors.get_tensor(ACL_DST);

    // Reshape LHS and DST to ensure compatibility with GEMM asm kernel (Batch dimension is 4th for lhs and dst within asm)
    // Collapse RHS (necessary to support dimensions larger than 3 in gemm assembly)
    lhs->info()->set_tensor_shape(TensorShape(_original_lhs_shape.x(), _original_lhs_shape.y(), 1, _original_lhs_shape.collapsed_from(2).z())); // Collapsed 3+ dimensions into z
    dst->info()->set_tensor_shape(TensorShape(_original_dst_shape.x(), _original_dst_shape.y(), 1, _original_dst_shape.collapsed_from(2).z())); // Collapsed 3+ dimensions into z
    rhs->info()->set_tensor_shape(_original_rhs_shape.collapsed_from(2));

    // Initialise object to handle stored transposed tensors in auxiliary memory
    CpuAuxTensorHandler lhs_transposed(offset_int_vec(TransposeLHS), _lhs_transposed, tensors, true);
    CpuAuxTensorHandler rhs_transposed(offset_int_vec(TransposeRHS), _rhs_transposed, tensors, true);

    // Create tensor pack for asm kernel
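    // asm_tensors starts as a copy of the caller's pack; if a transpose runs below, the
    // corresponding ACL_SRC_0 / ACL_SRC_1 entry is re-pointed at the transposed aux tensor.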
    ITensorPack asm_tensors(tensors);

    // Run transpose lhs if necessary
    if(_adj_lhs)
    {
        ITensorPack lhs_transpose_pack = { { TensorType::ACL_SRC, lhs }, { TensorType::ACL_DST, lhs_transposed.get() } };
        NEScheduler::get().schedule_op(_transpose_kernel_lhs.get(), Window::DimY, _transpose_kernel_lhs->window(), lhs_transpose_pack);
        asm_tensors.add_const_tensor(TensorType::ACL_SRC_0, lhs_transposed.get());
    }
    // Run transpose rhs if necessary
    if(_adj_rhs)
    {
        ITensorPack rhs_transpose_pack = { { TensorType::ACL_SRC, rhs }, { TensorType::ACL_DST, rhs_transposed.get() } };
        NEScheduler::get().schedule_op(_transpose_kernel_rhs.get(), Window::DimY, _transpose_kernel_rhs->window(), rhs_transpose_pack);
        asm_tensors.add_const_tensor(TensorType::ACL_SRC_1, rhs_transposed.get());
    }
    // Run asm kernel
    _asm_glue->run(asm_tensors);

    // Undo reshape of tensors
    dst->info()->set_tensor_shape(_original_dst_shape);
    lhs->info()->set_tensor_shape(_original_lhs_shape);
    rhs->info()->set_tensor_shape(_original_rhs_shape);
}

experimental::MemoryRequirements CpuMatMul::workspace() const
{
    return _aux_mem;
}
} // namespace cpu
} // namespace arm_compute