Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 1 | /* |
SiCong Li | 13bab71 | 2023-01-13 15:29:39 +0000 | [diff] [blame] | 2 | * Copyright (c) 2017-2023 Arm Limited. |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 3 | * |
| 4 | * SPDX-License-Identifier: MIT |
| 5 | * |
| 6 | * Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | * of this software and associated documentation files (the "Software"), to |
| 8 | * deal in the Software without restriction, including without limitation the |
| 9 | * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or |
| 10 | * sell copies of the Software, and to permit persons to whom the Software is |
| 11 | * furnished to do so, subject to the following conditions: |
| 12 | * |
| 13 | * The above copyright notice and this permission notice shall be included in all |
| 14 | * copies or substantial portions of the Software. |
| 15 | * |
| 16 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 21 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 22 | * SOFTWARE. |
| 23 | */ |
Georgios Pinitas | 7891a73 | 2021-08-20 21:39:25 +0100 | [diff] [blame] | 24 | #include "src/gpu/cl/operators/ClGemm.h" |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 25 | |
| 26 | #include "arm_compute/core/CL/CLKernelLibrary.h" |
| 27 | #include "arm_compute/core/CL/ICLTensor.h" |
| 28 | #include "arm_compute/core/Error.h" |
| 29 | #include "arm_compute/core/GPUTarget.h" |
| 30 | #include "arm_compute/core/Helpers.h" |
| 31 | #include "arm_compute/core/KernelDescriptors.h" |
| 32 | #include "arm_compute/core/Log.h" |
| 33 | #include "arm_compute/core/TensorInfo.h" |
| 34 | #include "arm_compute/core/Types.h" |
| 35 | #include "arm_compute/core/Utils.h" |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 36 | #include "arm_compute/core/utils/misc/ShapeCalculator.h" |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 37 | #include "arm_compute/core/Validate.h" |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 38 | #include "arm_compute/runtime/CL/CLScheduler.h" |
| 39 | #include "arm_compute/runtime/ITensorAllocator.h" |
Georgios Pinitas | 2b147ee | 2021-07-08 18:14:45 +0100 | [diff] [blame] | 40 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 41 | #include "src/common/utils/Log.h" |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 42 | #include "src/core/helpers/AutoConfiguration.h" |
| 43 | #include "src/core/helpers/MemoryHelpers.h" |
| 44 | #include "src/core/utils/helpers/float_ops.h" |
Georgios Pinitas | 7891a73 | 2021-08-20 21:39:25 +0100 | [diff] [blame] | 45 | #include "src/gpu/cl/IClKernel.h" |
| 46 | #include "src/gpu/cl/utils/ClAuxTensorHandler.h" |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 47 | #include "src/runtime/CL/gemm/CLGEMMKernelSelection.h" |
| 48 | #include "src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h" |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 49 | #include "support/Cast.h" |
| 50 | #include "utils/TypePrinter.h" |
| 51 | |
| 52 | namespace arm_compute |
| 53 | { |
| 54 | namespace opencl |
| 55 | { |
| 56 | using namespace arm_compute::misc::shape_calculator; |
| 57 | using namespace arm_compute::cl_gemm; |
| 58 | using namespace arm_compute::experimental; |
| 59 | using namespace arm_compute::utils::cast; |
| 60 | using namespace arm_compute::opencl::kernels; |
| 61 | |
| 62 | namespace |
| 63 | { |
| 64 | inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type) |
| 65 | { |
SiCongLi | 579ca84 | 2021-10-18 09:38:33 +0100 | [diff] [blame] | 66 | return kernel_type == CLGEMMKernelType::NATIVE ? false : true; |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 67 | } |
| 68 | //Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 69 | inline CLGEMMKernelType |
| 70 | auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run, bool constant_weights) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 71 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 72 | if (!constant_weights) |
Giorgio Arena | 4403ed3 | 2021-05-17 13:03:50 +0100 | [diff] [blame] | 73 | { |
Gian Marco Iodice | c9cecc0 | 2021-10-15 10:23:24 +0100 | [diff] [blame] | 74 | return CLGEMMKernelType::NATIVE; |
Giorgio Arena | 4403ed3 | 2021-05-17 13:03:50 +0100 | [diff] [blame] | 75 | } |
| 76 | |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 77 | auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 78 | if (bool(gemm_kernel)) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 79 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 80 | if (validate_gemm_kernel(gemm_kernel.gemm_type)) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 81 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 82 | ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", |
| 83 | to_string(gemm_kernel.gemm_type).c_str()); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 84 | return gemm_kernel.gemm_type; |
| 85 | } |
| 86 | } |
| 87 | gemm_kernel = auto_heuristics::select_default_gemm_kernel(query, reshape_b_only_on_first_run); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 88 | ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", |
| 89 | to_string(gemm_kernel.gemm_type).c_str()); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 90 | return gemm_kernel.gemm_type; |
| 91 | } |
| 92 | // Validate lhs_info and rhs_info for reshaped only rhs kernel |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 93 | inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, |
| 94 | const GEMMRHSMatrixInfo &rhs_info, |
| 95 | const ITensorInfo *a, |
| 96 | const ITensorInfo *b, |
| 97 | const ITensorInfo *c, |
| 98 | const ITensorInfo *output, |
| 99 | GEMMKernelInfo gemm_kernel_info) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 100 | { |
| 101 | // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped only rhs kernel |
| 102 | TensorInfo tmp_b_info{}; |
| 103 | // Validate reshape RHS kernel |
| 104 | auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info))); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 105 | if (!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info))) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 106 | { |
| 107 | return false; |
| 108 | } |
| 109 | // Validate mm kernel |
| 110 | gemm_kernel_info.lhs_info = lhs_info; |
| 111 | gemm_kernel_info.rhs_info = rhs_info; |
| 112 | gemm_kernel_info.has_pad_y = false; |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 113 | if (!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, |
| 114 | rhs_info, gemm_kernel_info))) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 115 | { |
| 116 | return false; |
| 117 | } |
| 118 | gemm_kernel_info.has_pad_y = true; |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 119 | if (!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, |
| 120 | rhs_info, gemm_kernel_info))) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 121 | { |
| 122 | return false; |
| 123 | } |
| 124 | return true; |
| 125 | } |
| 126 | |
| 127 | //Automatically select between mlgo (prioritized) and default heuristics for reshaped only rhs kernel configs |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 128 | inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> |
| 129 | auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query, |
| 130 | GEMMKernelInfo kernel_info, |
| 131 | const ITensorInfo *a, |
| 132 | const ITensorInfo *b, |
| 133 | const ITensorInfo *c, |
| 134 | const ITensorInfo *output) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 135 | { |
| 136 | auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(query); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 137 | if (config) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 138 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 139 | if (validate_lhs_rhs_info_reshaped_only_rhs(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info)) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 140 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 141 | ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE( |
| 142 | "Use reshaped_only_rhs config from mlgo heuristics: LHS info: %s ; RHS info: %s ", |
| 143 | to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str()); |
| 144 | return {config.lhs_info, config.rhs_info}; |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 145 | } |
| 146 | } |
| 147 | config = auto_heuristics::select_default_gemm_config_reshaped_only_rhs(query); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 148 | ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE( |
| 149 | "Use reshaped_only_rhs config from default heuristics: LHS info: %s ; RHS info: %s ", |
| 150 | to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str()); |
| 151 | return {config.lhs_info, config.rhs_info}; |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 152 | } |
| 153 | |
| 154 | // Validate lhs_info and rhs_info for reshaped kernel |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 155 | inline bool validate_lhs_rhs_info_reshaped(const GEMMLHSMatrixInfo &lhs_info, |
| 156 | const GEMMRHSMatrixInfo &rhs_info, |
| 157 | const ITensorInfo *a, |
| 158 | const ITensorInfo *b, |
| 159 | const ITensorInfo *c, |
| 160 | const ITensorInfo *output, |
| 161 | GEMMKernelInfo gemm_kernel_info, |
| 162 | bool reinterpret_input_as_3d) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 163 | { |
| 164 | // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped kernel |
| 165 | TensorInfo tmp_a_info{}; |
| 166 | TensorInfo tmp_b_info{}; |
| 167 | |
| 168 | // Validate reshape LHS kernel |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 169 | auto_init_if_empty(tmp_a_info, |
| 170 | a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, reinterpret_input_as_3d))); |
| 171 | if (!bool(ClGemmReshapeLhsMatrixKernel::validate(a, &tmp_a_info, lhs_info, reinterpret_input_as_3d))) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 172 | { |
| 173 | return false; |
| 174 | } |
| 175 | |
| 176 | // Validate reshape RHS kernel |
| 177 | auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info))); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 178 | if (!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info))) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 179 | { |
| 180 | return false; |
| 181 | } |
| 182 | // Validate mm kernel |
| 183 | gemm_kernel_info.lhs_info = lhs_info; |
| 184 | gemm_kernel_info.rhs_info = rhs_info; |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 185 | if (!bool(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, |
| 186 | rhs_info, gemm_kernel_info))) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 187 | { |
| 188 | return false; |
| 189 | } |
| 190 | return true; |
| 191 | } |
| 192 | |
| 193 | //Automatically select between mlgo (prioritized) and default heuristics for reshaped kernel configs |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 194 | inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> |
| 195 | auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery query, |
| 196 | GEMMKernelInfo kernel_info, |
| 197 | const ITensorInfo *a, |
| 198 | const ITensorInfo *b, |
| 199 | const ITensorInfo *c, |
| 200 | const ITensorInfo *output, |
| 201 | bool reinterpret_input_as_3d) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 202 | { |
| 203 | auto config = auto_heuristics::select_mlgo_gemm_config_reshaped(query); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 204 | if (config) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 205 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 206 | if (validate_lhs_rhs_info_reshaped(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info, |
| 207 | reinterpret_input_as_3d)) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 208 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 209 | ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE( |
| 210 | "Use reshaped config from mlgo heuristics: LHS info: %s ; RHS info: %s ", |
| 211 | to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str()); |
| 212 | return {config.lhs_info, config.rhs_info}; |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 213 | } |
| 214 | } |
| 215 | config = auto_heuristics::select_default_gemm_config_reshaped(query); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 216 | ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE( |
| 217 | "Use reshaped config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), |
| 218 | to_string(config.rhs_info).c_str()); |
| 219 | return {config.lhs_info, config.rhs_info}; |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 220 | } |
| 221 | } // namespace |
| 222 | |
| 223 | ClGemm::ClGemm() |
Gian Marco Iodice | c9cecc0 | 2021-10-15 10:23:24 +0100 | [diff] [blame] | 224 | : _reshape_lhs_kernel(std::make_unique<ClGemmReshapeLhsMatrixKernel>()), |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 225 | _reshape_rhs_kernel(std::make_unique<ClGemmReshapeRhsMatrixKernel>()), |
Gian Marco Iodice | c9cecc0 | 2021-10-15 10:23:24 +0100 | [diff] [blame] | 226 | _mm_native_kernel(std::make_unique<ClGemmMatrixMultiplyNativeKernel>()), |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 227 | _mm_reshaped_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedKernel>()), |
| 228 | _mm_reshaped_only_rhs_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsKernel>()), |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 229 | _mm_reshaped_only_rhs_mmul_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel>()), |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 230 | _tmp_a(), |
| 231 | _tmp_b(), |
| 232 | _reshape_b_only_on_first_run(false), |
Gian Marco Iodice | c9cecc0 | 2021-10-15 10:23:24 +0100 | [diff] [blame] | 233 | _gemm_kernel_type(CLGEMMKernelType::NATIVE), |
Manuel Bottini | d87aded | 2021-07-16 10:23:31 +0100 | [diff] [blame] | 234 | _is_prepared(false), |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 235 | _aux_mem(AuxTensorIdx::Count) |
| 236 | { |
| 237 | } |
| 238 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 239 | void ClGemm::configure_native(const CLCompileContext &compile_context, |
| 240 | ITensorInfo *a, |
| 241 | ITensorInfo *b, |
| 242 | ITensorInfo *c, |
| 243 | ITensorInfo *output, |
| 244 | float alpha, |
| 245 | float beta, |
| 246 | const GEMMInfo &gemm_info) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 247 | { |
Gian Marco Iodice | c9cecc0 | 2021-10-15 10:23:24 +0100 | [diff] [blame] | 248 | DataType data_type = a->data_type(); |
| 249 | bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 250 | const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); |
| 251 | const unsigned int n = b->dimension(0); |
| 252 | const unsigned int k = a->dimension(0); |
| 253 | const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); |
| 254 | const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); |
| 255 | const GPUTarget gpu_target = CLScheduler::get().target(); |
| 256 | bool broadcast_bias = gemm_info.broadcast_bias(); |
Gian Marco Iodice | c9cecc0 | 2021-10-15 10:23:24 +0100 | [diff] [blame] | 257 | |
| 258 | GEMMKernelInfo kernel_info; |
| 259 | kernel_info.m = m; |
| 260 | kernel_info.n = n; |
| 261 | kernel_info.k = k; |
| 262 | kernel_info.depth_output_gemm3d = depth_output_gemm3d; |
| 263 | kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d; |
| 264 | kernel_info.broadcast_bias = broadcast_bias; |
| 265 | kernel_info.activation_info = gemm_info.activation_info(); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 266 | |
| 267 | // Set the target for the kernels |
Gian Marco Iodice | c9cecc0 | 2021-10-15 10:23:24 +0100 | [diff] [blame] | 268 | _mm_native_kernel->set_target(gpu_target); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 269 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 270 | auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs( |
| 271 | auto_heuristics::CommonQuery{gpu_target, data_type, m, n, k, batch_size}); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 272 | |
| 273 | // Configure and tune matrix multiply kernel |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 274 | _mm_native_kernel->configure(compile_context, a, b, c, output, alpha, beta, config.lhs_info, config.rhs_info, |
| 275 | kernel_info); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 276 | } |
| 277 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 278 | void ClGemm::configure_reshaped(const CLCompileContext &compile_context, |
| 279 | ITensorInfo *a, |
| 280 | ITensorInfo *b, |
| 281 | ITensorInfo *c, |
| 282 | ITensorInfo *output, |
| 283 | float alpha, |
| 284 | float beta, |
| 285 | const GEMMInfo &gemm_info) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 286 | { |
| 287 | DataType data_type = a->data_type(); |
| 288 | bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 289 | const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); |
| 290 | const unsigned int n = b->dimension(0); |
| 291 | const unsigned int k = a->dimension(0); |
| 292 | const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); |
| 293 | const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); |
| 294 | const GPUTarget gpu_target = CLScheduler::get().target(); |
| 295 | bool broadcast_bias = gemm_info.broadcast_bias(); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 296 | |
| 297 | GEMMKernelInfo kernel_info; |
| 298 | kernel_info.m = m; |
| 299 | kernel_info.n = n; |
| 300 | kernel_info.k = k; |
| 301 | kernel_info.depth_output_gemm3d = depth_output_gemm3d; |
| 302 | kernel_info.reinterpret_input_as_3d = false; |
| 303 | kernel_info.broadcast_bias = broadcast_bias; |
| 304 | kernel_info.activation_info = gemm_info.activation_info(); |
| 305 | |
| 306 | // Set the target for the kernels |
| 307 | _reshape_lhs_kernel->set_target(gpu_target); |
Gian Marco Iodice | c9cecc0 | 2021-10-15 10:23:24 +0100 | [diff] [blame] | 308 | _mm_reshaped_kernel->set_target(gpu_target); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 309 | |
| 310 | GEMMLHSMatrixInfo lhs_info{}; |
| 311 | GEMMRHSMatrixInfo rhs_info{}; |
| 312 | |
| 313 | // Pick up the GEMM configuration |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 314 | std::tie(lhs_info, rhs_info) = |
| 315 | auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery{gpu_target, data_type, m, n, k, batch_size}, |
| 316 | kernel_info, a, b, c, output, gemm_info.reinterpret_input_as_3d()); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 317 | |
| 318 | _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d()); |
| 319 | _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info); |
| 320 | |
| 321 | // Configure and tune matrix multiply kernel |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 322 | _mm_reshaped_kernel->configure(compile_context, &_tmp_a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, |
| 323 | kernel_info); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 324 | |
| 325 | // Request memory for LHS and RHS reshape matrix |
| 326 | _aux_mem[LhsReshape] = MemoryInfo(offset_int_vec(LhsReshape), MemoryLifetime::Temporary, _tmp_a.total_size()); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 327 | _aux_mem[RhsReshape] = MemoryInfo( |
| 328 | offset_int_vec(RhsReshape), |
| 329 | _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size()); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 330 | } |
| 331 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 332 | void ClGemm::configure_reshaped_only_rhs(const CLCompileContext &compile_context, |
| 333 | ITensorInfo *a, |
| 334 | ITensorInfo *b, |
| 335 | ITensorInfo *c, |
| 336 | ITensorInfo *output, |
| 337 | float alpha, |
| 338 | float beta, |
| 339 | const GEMMInfo &gemm_info) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 340 | { |
| 341 | DataType data_type = a->data_type(); |
| 342 | bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 343 | const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); |
| 344 | const unsigned int n = b->dimension(0); |
| 345 | const unsigned int k = a->dimension(0); |
| 346 | const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); |
| 347 | const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); |
| 348 | const GPUTarget gpu_target = CLScheduler::get().target(); |
| 349 | bool broadcast_bias = gemm_info.broadcast_bias(); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 350 | |
| 351 | GEMMKernelInfo kernel_info; |
| 352 | kernel_info.m = m; |
| 353 | kernel_info.n = n; |
| 354 | kernel_info.k = k; |
| 355 | kernel_info.depth_output_gemm3d = depth_output_gemm3d; |
| 356 | kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d; |
| 357 | kernel_info.broadcast_bias = broadcast_bias; |
| 358 | kernel_info.activation_info = gemm_info.activation_info(); |
| 359 | |
| 360 | // Set the target for the kernels |
Gian Marco Iodice | c9cecc0 | 2021-10-15 10:23:24 +0100 | [diff] [blame] | 361 | _mm_reshaped_only_rhs_kernel->set_target(gpu_target); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 362 | |
| 363 | GEMMLHSMatrixInfo lhs_info{}; |
| 364 | GEMMRHSMatrixInfo rhs_info{}; |
| 365 | |
| 366 | // Pick up the GEMM configuration |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 367 | std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped_only_rhs( |
| 368 | auto_heuristics::CommonQuery{gpu_target, data_type, m, n, k, batch_size}, kernel_info, a, b, c, output); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 369 | |
| 370 | // Transpose matrix |
| 371 | _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info); |
| 372 | |
| 373 | // Configure two variants of CLGEMMMatrixMultiplyReshapedOnlyRHSKernel (has_pad_y = false/true) |
| 374 | // During the prepare stage we check the padding requirement for the lhs and dst tensors. If they do not have |
| 375 | // pad y, we dispatch CLGEMMMatrixMultiplyReshapedOnlyRHSKernel with has_pad_y = false |
| 376 | |
| 377 | // Configure matrix multiply kernel with no y padding support |
| 378 | kernel_info.has_pad_y = false; |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 379 | _mm_reshaped_only_rhs_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, |
| 380 | kernel_info); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 381 | |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 382 | // Request memory for RHS reshape matrix |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 383 | _aux_mem[RhsReshape] = MemoryInfo( |
| 384 | offset_int_vec(RhsReshape), |
| 385 | _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size()); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 386 | } |
| 387 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 388 | void ClGemm::configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, |
| 389 | ITensorInfo *a, |
| 390 | ITensorInfo *b, |
| 391 | ITensorInfo *c, |
| 392 | ITensorInfo *output, |
| 393 | float alpha, |
| 394 | float beta, |
| 395 | const GEMMInfo &gemm_info) |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 396 | { |
| 397 | DataType data_type = a->data_type(); |
| 398 | bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 399 | const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); |
| 400 | const unsigned int n = b->dimension(0); |
| 401 | const unsigned int k = a->dimension(0); |
| 402 | const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); |
| 403 | const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); |
| 404 | const GPUTarget gpu_target = CLScheduler::get().target(); |
| 405 | bool broadcast_bias = gemm_info.broadcast_bias(); |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 406 | |
| 407 | GEMMKernelInfo kernel_info; |
| 408 | kernel_info.m = m; |
| 409 | kernel_info.n = n; |
| 410 | kernel_info.k = k; |
| 411 | kernel_info.depth_output_gemm3d = depth_output_gemm3d; |
| 412 | kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d; |
| 413 | kernel_info.broadcast_bias = broadcast_bias; |
| 414 | kernel_info.activation_info = gemm_info.activation_info(); |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 415 | |
| 416 | // Set the target for the kernels |
| 417 | _mm_reshaped_only_rhs_mmul_kernel->set_target(gpu_target); |
| 418 | |
| 419 | GEMMLHSMatrixInfo lhs_info{}; |
| 420 | GEMMRHSMatrixInfo rhs_info{}; |
| 421 | |
| 422 | // Pick up the GEMM configuration |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 423 | auto gemm_config = select_default_gemm_config_reshaped_only_rhs( |
| 424 | auto_heuristics::CommonQuery{gpu_target, data_type, m, n, k, batch_size}); |
| 425 | lhs_info = gemm_config.lhs_info; |
| 426 | rhs_info = gemm_config.rhs_info; |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 427 | // Force H0 to 4 in order to use the MMUL extension |
| 428 | rhs_info.h0 = 4; |
| 429 | |
| 430 | // Reshape Rhs matrix |
| 431 | _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info); |
| 432 | |
| 433 | // Configure matrix multiply kernel with no y padding support |
| 434 | kernel_info.has_pad_y = false; |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 435 | _mm_reshaped_only_rhs_mmul_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, |
| 436 | rhs_info, kernel_info); |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 437 | |
| 438 | // Request memory for RHS reshape matrix |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 439 | _aux_mem[RhsReshape] = MemoryInfo( |
| 440 | offset_int_vec(RhsReshape), |
| 441 | _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size()); |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 442 | } |
| 443 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 444 | Status ClGemm::validate_native(const ITensorInfo *a, |
| 445 | const ITensorInfo *b, |
| 446 | const ITensorInfo *c, |
| 447 | const ITensorInfo *output, |
| 448 | float alpha, |
| 449 | float beta, |
| 450 | const GEMMInfo &gemm_info) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 451 | { |
| 452 | ARM_COMPUTE_UNUSED(alpha); |
| 453 | ARM_COMPUTE_UNUSED(output); |
| 454 | |
| 455 | // Get the GPU target |
| 456 | const GPUTarget gpu_target = CLScheduler::get().target(); |
Gian Marco Iodice | c9cecc0 | 2021-10-15 10:23:24 +0100 | [diff] [blame] | 457 | DataType data_type = a->data_type(); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 458 | bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 459 | const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); |
| 460 | const unsigned int n = b->dimension(0); |
| 461 | const unsigned int k = a->dimension(0); |
| 462 | const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); |
| 463 | const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); |
| 464 | const bool broadcast_bias = gemm_info.broadcast_bias(); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 465 | |
Gian Marco Iodice | c9cecc0 | 2021-10-15 10:23:24 +0100 | [diff] [blame] | 466 | GEMMKernelInfo kernel_info; |
| 467 | kernel_info.m = m; |
| 468 | kernel_info.n = n; |
| 469 | kernel_info.k = k; |
| 470 | kernel_info.depth_output_gemm3d = depth_output_gemm3d; |
| 471 | kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d; |
| 472 | kernel_info.broadcast_bias = broadcast_bias; |
| 473 | kernel_info.activation_info = gemm_info.activation_info(); |
| 474 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 475 | auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs( |
| 476 | auto_heuristics::CommonQuery{gpu_target, data_type, m, n, k, batch_size}); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 477 | |
| 478 | // Validate matrix multiply |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 479 | ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyNativeKernel::validate( |
| 480 | a, b, c, output, alpha, beta, config.lhs_info, config.rhs_info, kernel_info)); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 481 | |
| 482 | return Status{}; |
| 483 | } |
| 484 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 485 | Status ClGemm::validate_reshaped(const ITensorInfo *a, |
| 486 | const ITensorInfo *b, |
| 487 | const ITensorInfo *c, |
| 488 | const ITensorInfo *output, |
| 489 | float alpha, |
| 490 | float beta, |
| 491 | const GEMMInfo &gemm_info) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 492 | { |
| 493 | ARM_COMPUTE_UNUSED(alpha); |
| 494 | ARM_COMPUTE_UNUSED(output); |
| 495 | |
| 496 | TensorInfo tmp_a_info{}; |
| 497 | TensorInfo tmp_b_info{}; |
| 498 | |
| 499 | // Get the GPU target |
| 500 | const GPUTarget gpu_target = CLScheduler::get().target(); |
| 501 | DataType data_type = a->data_type(); |
| 502 | bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 503 | const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); |
| 504 | const unsigned int n = b->dimension(0); |
| 505 | const unsigned int k = a->dimension(0); |
| 506 | const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); |
| 507 | const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); |
| 508 | const bool broadcast_bias = gemm_info.broadcast_bias(); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 509 | |
| 510 | GEMMKernelInfo kernel_info; |
| 511 | kernel_info.m = m; |
| 512 | kernel_info.n = n; |
| 513 | kernel_info.k = k; |
| 514 | kernel_info.depth_output_gemm3d = depth_output_gemm3d; |
| 515 | kernel_info.reinterpret_input_as_3d = false; |
| 516 | kernel_info.broadcast_bias = broadcast_bias; |
| 517 | kernel_info.activation_info = gemm_info.activation_info(); |
| 518 | |
| 519 | GEMMLHSMatrixInfo lhs_info; |
| 520 | GEMMRHSMatrixInfo rhs_info; |
| 521 | |
| 522 | // Pick up the GEMM configuration |
| 523 | // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 524 | const auto gemm_config = |
| 525 | select_default_gemm_config_reshaped(auto_heuristics::CommonQuery{gpu_target, data_type, m, n, k, batch_size}); |
| 526 | lhs_info = gemm_config.lhs_info; |
| 527 | rhs_info = gemm_config.rhs_info; |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 528 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 529 | auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape( |
| 530 | compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d()))); |
| 531 | ARM_COMPUTE_RETURN_ON_ERROR( |
| 532 | ClGemmReshapeLhsMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d())); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 533 | |
| 534 | auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info))); |
| 535 | ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)); |
| 536 | |
| 537 | // Validate matrix multiply |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 538 | ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, |
| 539 | beta, lhs_info, rhs_info, kernel_info)); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 540 | |
| 541 | return Status{}; |
| 542 | } |
| 543 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 544 | Status ClGemm::validate_reshaped_only_rhs(const ITensorInfo *a, |
| 545 | const ITensorInfo *b, |
| 546 | const ITensorInfo *c, |
| 547 | const ITensorInfo *output, |
| 548 | float alpha, |
| 549 | float beta, |
| 550 | const GEMMInfo &gemm_info) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 551 | { |
| 552 | ARM_COMPUTE_UNUSED(alpha); |
| 553 | ARM_COMPUTE_UNUSED(output); |
| 554 | |
| 555 | TensorInfo tmp_b_info{}; |
| 556 | |
| 557 | // Get the GPU target |
| 558 | const GPUTarget gpu_target = CLScheduler::get().target(); |
| 559 | const DataType data_type = a->data_type(); |
| 560 | bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 561 | const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); |
| 562 | const unsigned int n = b->dimension(0); |
| 563 | const unsigned int k = a->dimension(0); |
| 564 | const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); |
| 565 | const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); |
| 566 | const bool broadcast_bias = gemm_info.broadcast_bias(); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 567 | |
| 568 | GEMMKernelInfo kernel_info; |
| 569 | kernel_info.m = m; |
| 570 | kernel_info.n = n; |
| 571 | kernel_info.k = k; |
| 572 | kernel_info.depth_output_gemm3d = depth_output_gemm3d; |
| 573 | kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d; |
| 574 | kernel_info.broadcast_bias = broadcast_bias; |
| 575 | kernel_info.activation_info = gemm_info.activation_info(); |
| 576 | |
| 577 | GEMMLHSMatrixInfo lhs_info; |
| 578 | GEMMRHSMatrixInfo rhs_info; |
| 579 | |
| 580 | // Pick up the GEMM configuration |
| 581 | // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 582 | const auto gemm_config = select_default_gemm_config_reshaped_only_rhs( |
| 583 | auto_heuristics::CommonQuery{gpu_target, data_type, m, n, k, batch_size}); |
| 584 | lhs_info = gemm_config.lhs_info; |
| 585 | rhs_info = gemm_config.rhs_info; |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 586 | |
| 587 | auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info))); |
| 588 | ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)); |
| 589 | |
| 590 | // Validate matrix multiply |
| 591 | kernel_info.has_pad_y = false; |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 592 | ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate( |
| 593 | a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info)); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 594 | |
Ramy Elgammal | 451c309 | 2022-02-01 23:01:27 +0000 | [diff] [blame] | 595 | kernel_info.has_pad_y = true; |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 596 | ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate( |
| 597 | a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info)); |
Ramy Elgammal | 451c309 | 2022-02-01 23:01:27 +0000 | [diff] [blame] | 598 | |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 599 | return Status{}; |
| 600 | } |
| 601 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 602 | Status ClGemm::validate_reshaped_only_rhs_mmul(const ITensorInfo *a, |
| 603 | const ITensorInfo *b, |
| 604 | const ITensorInfo *c, |
| 605 | const ITensorInfo *output, |
| 606 | float alpha, |
| 607 | float beta, |
| 608 | const GEMMInfo &gemm_info) |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 609 | { |
| 610 | ARM_COMPUTE_UNUSED(alpha); |
| 611 | ARM_COMPUTE_UNUSED(output); |
| 612 | TensorInfo tmp_b_info{}; |
| 613 | |
| 614 | // Get the GPU target |
| 615 | const GPUTarget gpu_target = CLScheduler::get().target(); |
| 616 | const DataType data_type = a->data_type(); |
| 617 | bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 618 | const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); |
| 619 | const unsigned int n = b->dimension(0); |
| 620 | const unsigned int k = a->dimension(0); |
| 621 | const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); |
| 622 | const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); |
| 623 | const bool broadcast_bias = gemm_info.broadcast_bias(); |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 624 | |
| 625 | GEMMKernelInfo kernel_info; |
| 626 | kernel_info.m = m; |
| 627 | kernel_info.n = n; |
| 628 | kernel_info.k = k; |
| 629 | kernel_info.depth_output_gemm3d = depth_output_gemm3d; |
| 630 | kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d; |
| 631 | kernel_info.broadcast_bias = broadcast_bias; |
| 632 | kernel_info.activation_info = gemm_info.activation_info(); |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 633 | |
| 634 | GEMMLHSMatrixInfo lhs_info; |
| 635 | GEMMRHSMatrixInfo rhs_info; |
| 636 | |
| 637 | // Pick up the GEMM configuration |
| 638 | // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 639 | const auto gemm_config = select_default_gemm_config_reshaped_only_rhs( |
| 640 | auto_heuristics::CommonQuery{gpu_target, data_type, m, n, k, batch_size}); |
| 641 | lhs_info = gemm_config.lhs_info; |
| 642 | rhs_info = gemm_config.rhs_info; |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 643 | // Force H0 to 4 in order to use the MMUL extension |
| 644 | rhs_info.h0 = 4; |
| 645 | |
| 646 | auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info))); |
| 647 | ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)); |
| 648 | |
| 649 | // Validate matrix multiply |
| 650 | kernel_info.has_pad_y = false; |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 651 | ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::validate( |
| 652 | a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info)); |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 653 | |
| 654 | return Status{}; |
| 655 | } |
| 656 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 657 | void ClGemm::configure(const CLCompileContext &compile_context, |
| 658 | ITensorInfo *a, |
| 659 | ITensorInfo *b, |
| 660 | ITensorInfo *c, |
| 661 | ITensorInfo *output, |
| 662 | float alpha, |
| 663 | float beta, |
| 664 | const GEMMInfo &gemm_info) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 665 | { |
| 666 | ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output); |
| 667 | |
| 668 | // Perform validation step |
| 669 | ARM_COMPUTE_ERROR_THROW_ON(validate(a, b, c, output, alpha, beta, gemm_info)); |
ramelg01 | 2e53f17 | 2021-09-22 10:48:25 +0100 | [diff] [blame] | 670 | ARM_COMPUTE_LOG_PARAMS(a, b, c, output, alpha, beta, gemm_info); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 671 | |
| 672 | // Check if we need to reshape the matrix B only on the first run |
| 673 | _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run(); |
Georgios Pinitas | f5d51f3 | 2021-08-17 16:09:10 +0100 | [diff] [blame] | 674 | _is_prepared = gemm_info.retain_internal_weights(); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 675 | |
| 676 | bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 677 | const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); |
| 678 | const unsigned int n = b->dimension(0); |
| 679 | const unsigned int k = a->dimension(0); |
| 680 | const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 681 | |
| 682 | // Select GEMMType |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 683 | _gemm_kernel_type = auto_select_gemm_kernel( |
| 684 | auto_heuristics::CommonQuery{CLScheduler::get().target(), a->data_type(), m, n, k, batch_size}, |
| 685 | _reshape_b_only_on_first_run, b->are_values_constant()); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 686 | |
| 687 | const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); |
| 688 | |
| 689 | ITensorInfo *c_to_use = fuse_add_c ? c : nullptr; |
| 690 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 691 | switch (_gemm_kernel_type) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 692 | { |
Gian Marco Iodice | c9cecc0 | 2021-10-15 10:23:24 +0100 | [diff] [blame] | 693 | case CLGEMMKernelType::NATIVE: |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 694 | { |
Gian Marco Iodice | c9cecc0 | 2021-10-15 10:23:24 +0100 | [diff] [blame] | 695 | configure_native(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 696 | break; |
| 697 | } |
| 698 | case CLGEMMKernelType::RESHAPED: |
| 699 | { |
Gian Marco Iodice | c9cecc0 | 2021-10-15 10:23:24 +0100 | [diff] [blame] | 700 | configure_reshaped(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 701 | break; |
| 702 | } |
| 703 | case CLGEMMKernelType::RESHAPED_ONLY_RHS: |
| 704 | { |
| 705 | configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info); |
| 706 | break; |
| 707 | } |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 708 | case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL: |
| 709 | { |
| 710 | configure_reshaped_only_rhs_mmul(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info); |
| 711 | break; |
| 712 | } |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 713 | default: |
| 714 | { |
| 715 | ARM_COMPUTE_ERROR("GEMMType not supported"); |
| 716 | } |
| 717 | } |
| 718 | } |
| 719 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 720 | Status ClGemm::validate(const ITensorInfo *a, |
| 721 | const ITensorInfo *b, |
| 722 | const ITensorInfo *c, |
| 723 | const ITensorInfo *output, |
| 724 | float alpha, |
| 725 | float beta, |
| 726 | const GEMMInfo &gemm_info) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 727 | { |
| 728 | // Get the GPU target |
| 729 | bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 730 | const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); |
| 731 | const unsigned int n = b->dimension(0); |
| 732 | const unsigned int k = a->dimension(0); |
| 733 | const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 734 | |
SiCong Li | 13bab71 | 2023-01-13 15:29:39 +0000 | [diff] [blame] | 735 | // Check data type early because the auto_select_gemm_kernel has assertions on supported data types |
| 736 | ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32, DataType::F16); |
| 737 | |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 738 | // Select GEMMType |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 739 | CLGEMMKernelType gemm_kernel_type = auto_select_gemm_kernel( |
| 740 | auto_heuristics::CommonQuery{ |
| 741 | CLScheduler::get().target(), |
| 742 | a->data_type(), |
| 743 | m, |
| 744 | n, |
| 745 | k, |
| 746 | batch_size, |
| 747 | }, |
| 748 | gemm_info.reshape_b_only_on_first_run(), b->are_values_constant()); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 749 | |
| 750 | const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); |
| 751 | |
| 752 | const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr; |
| 753 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 754 | switch (gemm_kernel_type) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 755 | { |
Gian Marco Iodice | c9cecc0 | 2021-10-15 10:23:24 +0100 | [diff] [blame] | 756 | case CLGEMMKernelType::NATIVE: |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 757 | { |
Gian Marco Iodice | c9cecc0 | 2021-10-15 10:23:24 +0100 | [diff] [blame] | 758 | ARM_COMPUTE_RETURN_ON_ERROR(validate_native(a, b, c_to_use, output, alpha, beta, gemm_info)); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 759 | break; |
| 760 | } |
| 761 | case CLGEMMKernelType::RESHAPED: |
| 762 | { |
| 763 | ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped(a, b, c_to_use, output, alpha, beta, gemm_info)); |
| 764 | break; |
| 765 | } |
| 766 | case CLGEMMKernelType::RESHAPED_ONLY_RHS: |
| 767 | { |
| 768 | ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info)); |
| 769 | break; |
| 770 | } |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 771 | case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL: |
| 772 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 773 | ARM_COMPUTE_RETURN_ON_ERROR( |
| 774 | validate_reshaped_only_rhs_mmul(a, b, c_to_use, output, alpha, beta, gemm_info)); |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 775 | break; |
| 776 | } |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 777 | default: |
| 778 | { |
| 779 | ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported"); |
| 780 | } |
| 781 | } |
| 782 | |
| 783 | return Status{}; |
| 784 | } |
| 785 | |
| 786 | void ClGemm::run(ITensorPack &tensors) |
| 787 | { |
SiCongLi | afa1972 | 2021-10-24 19:12:33 +0100 | [diff] [blame] | 788 | const ITensor *lhs = tensors.get_const_tensor(ACL_SRC_0); |
| 789 | const ITensor *rhs = tensors.get_const_tensor(ACL_SRC_1); |
| 790 | ITensor *dst = tensors.get_tensor(ACL_DST); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 791 | |
| 792 | ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, dst); |
| 793 | |
| 794 | CLAuxTensorHandler lhs_reshaped(offset_int_vec(LhsReshape), _tmp_a, tensors, true); |
| 795 | CLAuxTensorHandler rhs_reshaped(offset_int_vec(RhsReshape), _tmp_b, tensors, true); |
| 796 | |
| 797 | // Prepare the consts if needed |
| 798 | prepare(tensors); |
| 799 | |
| 800 | // Run matrix multiply kernel |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 801 | switch (_gemm_kernel_type) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 802 | { |
Gian Marco Iodice | c9cecc0 | 2021-10-15 10:23:24 +0100 | [diff] [blame] | 803 | case CLGEMMKernelType::NATIVE: |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 804 | { |
Gian Marco Iodice | c9cecc0 | 2021-10-15 10:23:24 +0100 | [diff] [blame] | 805 | CLScheduler::get().enqueue_op(*_mm_native_kernel, tensors, true); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 806 | break; |
| 807 | } |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 808 | case CLGEMMKernelType::RESHAPED: |
| 809 | { |
| 810 | // Run interleave kernel |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 811 | ITensorPack reshape_lhs_pack{{ACL_SRC, lhs}, {ACL_DST, lhs_reshaped.get()}}; |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 812 | CLScheduler::get().enqueue_op(*_reshape_lhs_kernel, reshape_lhs_pack, false); |
| 813 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 814 | if (!_reshape_b_only_on_first_run) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 815 | { |
| 816 | // Run transpose kernel |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 817 | ITensorPack reshape_rhs_pack{{ACL_SRC, rhs}, {ACL_DST, rhs_reshaped.get()}}; |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 818 | CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false); |
| 819 | } |
SiCongLi | 579ca84 | 2021-10-18 09:38:33 +0100 | [diff] [blame] | 820 | // Copy original tensor pack and overwrite lhs and rhs with reshaped counterparts |
| 821 | ITensorPack gemm_reshaped_pack(tensors); |
| 822 | gemm_reshaped_pack.add_const_tensor(ACL_SRC_0, lhs_reshaped.get()); |
| 823 | gemm_reshaped_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get()); |
Manuel Bottini | d87aded | 2021-07-16 10:23:31 +0100 | [diff] [blame] | 824 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 825 | if (_gemm_kernel_type == CLGEMMKernelType::RESHAPED) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 826 | { |
| 827 | CLScheduler::get().enqueue_op(*_mm_reshaped_kernel, gemm_reshaped_pack, true); |
| 828 | } |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 829 | break; |
| 830 | } |
| 831 | case CLGEMMKernelType::RESHAPED_ONLY_RHS: |
| 832 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 833 | if (!_reshape_b_only_on_first_run) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 834 | { |
| 835 | // Run transpose kernel |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 836 | ITensorPack reshape_rhs_pack{{ACL_SRC, rhs}, {ACL_DST, rhs_reshaped.get()}}; |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 837 | CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false); |
| 838 | } |
| 839 | // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement |
| 840 | // Check if the lhs or dst tensors have padding |
| 841 | const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom; |
| 842 | const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom; |
| 843 | bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0); |
| 844 | |
SiCongLi | afa1972 | 2021-10-24 19:12:33 +0100 | [diff] [blame] | 845 | // Copy original tensor pack and overwrite rhs with reshaped counterpart |
| 846 | ITensorPack gemm_reshaped_onlyrhs_pack(tensors); |
| 847 | gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get()); |
| 848 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 849 | if (has_pad_y) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 850 | { |
ramelg01 | 9cca592 | 2021-11-11 10:05:00 +0000 | [diff] [blame] | 851 | ARM_COMPUTE_ERROR_ON(has_pad_y); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 852 | } |
| 853 | else |
| 854 | { |
| 855 | CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_kernel, gemm_reshaped_onlyrhs_pack, true); |
| 856 | } |
| 857 | break; |
| 858 | } |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 859 | case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL: |
| 860 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 861 | if (!_reshape_b_only_on_first_run) |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 862 | { |
| 863 | // Run transpose kernel |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 864 | ITensorPack reshape_rhs_pack{{ACL_SRC, rhs}, {ACL_DST, rhs_reshaped.get()}}; |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 865 | CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false); |
| 866 | } |
| 867 | // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement |
| 868 | // Check if the lhs or dst tensors have padding |
| 869 | const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom; |
| 870 | const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom; |
| 871 | bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0); |
| 872 | |
| 873 | // Copy original tensor pack and overwrite rhs with reshaped counterpart |
| 874 | ITensorPack gemm_reshaped_onlyrhs_pack(tensors); |
| 875 | gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get()); |
| 876 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 877 | if (has_pad_y) |
Gunes Bayir | 4bfc70e | 2021-12-10 16:17:56 +0000 | [diff] [blame] | 878 | { |
| 879 | ARM_COMPUTE_ERROR_ON(has_pad_y); |
| 880 | } |
| 881 | else |
| 882 | { |
| 883 | CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_mmul_kernel, gemm_reshaped_onlyrhs_pack, true); |
| 884 | } |
| 885 | break; |
| 886 | } |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 887 | default: |
| 888 | { |
| 889 | ARM_COMPUTE_ERROR("GEMMType not supported"); |
| 890 | } |
| 891 | } |
| 892 | } |
| 893 | |
| 894 | void ClGemm::prepare(ITensorPack &constants) |
| 895 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 896 | if (!_is_prepared) |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 897 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 898 | const ITensor *src1 = constants.get_const_tensor(ACL_SRC_1); |
| 899 | ICLTensor *rhs_aux = |
| 900 | utils::cast::polymorphic_downcast<ICLTensor *>(constants.get_tensor(offset_int_vec(RhsReshape))); |
Georgios Pinitas | 2b147ee | 2021-07-08 18:14:45 +0100 | [diff] [blame] | 901 | |
Manuel Bottini | d87aded | 2021-07-16 10:23:31 +0100 | [diff] [blame] | 902 | // If memory for RHS is persistent and src1 is provided re-transform else assume that RHS is transformed |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 903 | if ((_aux_mem[AuxTensorIdx::RhsReshape].lifetime == MemoryLifetime::Persistent) && |
| 904 | (src1 != nullptr && rhs_aux != nullptr) && rhs_aux) |
Manuel Bottini | d87aded | 2021-07-16 10:23:31 +0100 | [diff] [blame] | 905 | { |
| 906 | ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Transforming RHS Matrix!"); |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 907 | |
Manuel Bottini | d87aded | 2021-07-16 10:23:31 +0100 | [diff] [blame] | 908 | CLAuxTensorHandler rhs_reshaped(_tmp_b, *rhs_aux); |
| 909 | ARM_COMPUTE_ERROR_ON(rhs_reshaped.get()->cl_buffer().get() == nullptr); |
| 910 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame^] | 911 | ITensorPack reshape_rhs_pack{{ACL_SRC, src1}, {ACL_DST, rhs_reshaped.get()}}; |
Manuel Bottini | d87aded | 2021-07-16 10:23:31 +0100 | [diff] [blame] | 912 | CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, true); |
| 913 | } |
| 914 | _is_prepared = true; |
Georgios Pinitas | 856f66e | 2021-04-22 21:13:21 +0100 | [diff] [blame] | 915 | } |
| 916 | } |
| 917 | |
| 918 | experimental::MemoryRequirements ClGemm::workspace() const |
| 919 | { |
| 920 | return _aux_mem; |
| 921 | } |
| 922 | } // namespace opencl |
| 923 | } // namespace arm_compute |