Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 1 | /* |
Gunes Bayir | ef63739 | 2024-02-12 21:32:51 +0000 | [diff] [blame] | 2 | * Copyright (c) 2018-2024 Arm Limited. |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 3 | * |
| 4 | * SPDX-License-Identifier: MIT |
| 5 | * |
| 6 | * Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | * of this software and associated documentation files (the "Software"), to |
| 8 | * deal in the Software without restriction, including without limitation the |
| 9 | * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or |
| 10 | * sell copies of the Software, and to permit persons to whom the Software is |
| 11 | * furnished to do so, subject to the following conditions: |
| 12 | * |
| 13 | * The above copyright notice and this permission notice shall be included in all |
| 14 | * copies or substantial portions of the Software. |
| 15 | * |
| 16 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 21 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 22 | * SOFTWARE. |
| 23 | */ |
Georgios Pinitas | 7891a73 | 2021-08-20 21:39:25 +0100 | [diff] [blame] | 24 | #include "src/cpu/operators/internal/CpuGemmAssemblyDispatch.h" |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 25 | |
Sang-Hoon Park | 68dd25f | 2020-10-19 16:00:11 +0100 | [diff] [blame] | 26 | #include "arm_compute/runtime/NEON/NEScheduler.h" |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 27 | |
Sang-Hoon Park | 68dd25f | 2020-10-19 16:00:11 +0100 | [diff] [blame] | 28 | #include "src/core/CPP/Validate.h" |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 29 | #include "src/core/helpers/MemoryHelpers.h" |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 30 | #include "src/core/NEON/kernels/arm_gemm/utils.hpp" |
Michele Di Giorgio | d02d5ed | 2021-01-22 09:47:04 +0000 | [diff] [blame] | 31 | #include "src/core/utils/AssemblyUtils.h" |
Georgios Pinitas | 7891a73 | 2021-08-20 21:39:25 +0100 | [diff] [blame] | 32 | #include "src/cpu/kernels/assembly/arm_gemm.hpp" |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 33 | #include "src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h" |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 34 | #include "src/cpu/operators/CpuTranspose.h" |
Georgios Pinitas | 7891a73 | 2021-08-20 21:39:25 +0100 | [diff] [blame] | 35 | #include "src/cpu/utils/CpuAuxTensorHandler.h" |
Michele Di Giorgio | 6ad60af | 2020-06-09 14:52:15 +0100 | [diff] [blame] | 36 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 37 | #include <arm_neon.h> |
| 38 | |
Anthony Barbier | c8e84b5 | 2018-07-17 16:48:42 +0100 | [diff] [blame] | 39 | namespace arm_compute |
| 40 | { |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 41 | namespace cpu |
| 42 | { |
SiCong Li | dba672c | 2023-04-06 16:30:18 +0100 | [diff] [blame] | 43 | namespace |
| 44 | { |
| 45 | /** Run pretranspose_B_array in parallel (1D static scheduling) |
| 46 | * |
| 47 | * @tparam TypeInput |
| 48 | * @tparam TypeOutput |
| 49 | * |
| 50 | * @param[in] gemm_asm GemmCommon kernel to run |
| 51 | * @param[in] dst Pretransposed B array |
| 52 | * @param[in] src B array to be pretransposed |
| 53 | * @param[in] src_ld Stride in y |
| 54 | * @param[in] src_multi_stride Stride in z ("multi") |
| 55 | * @param[in] num_threads Number of threads to run this method. Must be >= 1 |
| 56 | */ |
| 57 | template <typename TypeInput, typename TypeOutput> |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 58 | void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutput> *gemm_asm, |
| 59 | ITensor *dst, |
| 60 | const TypeInput *src, |
| 61 | int src_ld, |
| 62 | int src_multi_stride, |
Gunes Bayir | ef63739 | 2024-02-12 21:32:51 +0000 | [diff] [blame] | 63 | unsigned int num_threads, |
| 64 | bool transpose) |
SiCong Li | dba672c | 2023-04-06 16:30:18 +0100 | [diff] [blame] | 65 | { |
| 66 | ARM_COMPUTE_ERROR_ON(gemm_asm == nullptr); |
| 67 | ARM_COMPUTE_ERROR_ON(num_threads == 0); |
| 68 | // The window size is also the total workload size |
| 69 | const unsigned int wsize = gemm_asm->get_B_pretranspose_window_size(); |
| 70 | |
| 71 | std::vector<IScheduler::Workload> workloads(num_threads); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 72 | for (unsigned int t = 0; t < num_threads; ++t) |
SiCong Li | dba672c | 2023-04-06 16:30:18 +0100 | [diff] [blame] | 73 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 74 | workloads[t] = [=](const ThreadInfo &info) |
SiCong Li | dba672c | 2023-04-06 16:30:18 +0100 | [diff] [blame] | 75 | { |
| 76 | const unsigned int start = (info.thread_id * wsize) / num_threads; |
| 77 | const unsigned int end = ((info.thread_id + 1) * wsize) / num_threads; |
| 78 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 79 | if (start < end) |
SiCong Li | dba672c | 2023-04-06 16:30:18 +0100 | [diff] [blame] | 80 | { |
Gunes Bayir | ef63739 | 2024-02-12 21:32:51 +0000 | [diff] [blame] | 81 | gemm_asm->pretranspose_B_array_part(dst->buffer(), src, src_ld, src_multi_stride, transpose, start, |
| 82 | end); |
SiCong Li | dba672c | 2023-04-06 16:30:18 +0100 | [diff] [blame] | 83 | } |
| 84 | }; |
| 85 | } |
| 86 | NEScheduler::get().run_tagged_workloads(workloads, "CpuGemmAssemblyDispatch/pretranspose_B_array"); |
| 87 | } |
| 88 | } // namespace |
| 89 | |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 90 | using namespace arm_compute::experimental; |
| 91 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 92 | namespace |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 93 | { |
/** Deleter functor that releases a block through free().
 *
 * Intended for smart pointers that own memory obtained from malloc().
 */
struct free_delete
{
    /** Release the given block
     *
     * @param[in] ptr Block to release
     */
    void operator()(void *ptr)
    {
        free(ptr);
    }
};
| 101 | |
/** GEMM problem description produced by extract_parameters().
 *
 * Every member carries a default so that a default-constructed instance never exposes
 * indeterminate values. The defaults for batches, multis, sections and indirect are the
 * same values extract_parameters() assigns before specialising them.
 */
struct Params
{
    unsigned int M{0};            /**< Taken from the destination shape (y, or y * z for 3D output) */
    unsigned int N{0};            /**< Taken from the destination shape (x) */
    unsigned int K{0};            /**< Taken from the shape of A (x) */
    unsigned int batches{1};      /**< Number of batches */
    unsigned int multis{1};       /**< Number of "multi" GEMMs */
    unsigned int sections{1};     /**< Number of sections, used by the convolution/indirect methods */
    bool         indirect{false}; /**< True when the convolution or indirect method is selected */
};
| 112 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 113 | Params extract_parameters(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 114 | { |
| 115 | ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 116 | Params p; |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 117 | p.M = d->tensor_shape().y(); |
| 118 | p.K = a->tensor_shape().x(); |
| 119 | p.N = d->tensor_shape().x(); |
Georgios Pinitas | 4c634e0 | 2020-12-01 02:17:19 +0000 | [diff] [blame] | 120 | p.batches = 1; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 121 | p.multis = 1; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 122 | p.sections = 1; |
Georgios Pinitas | 4c634e0 | 2020-12-01 02:17:19 +0000 | [diff] [blame] | 123 | p.indirect = false; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 124 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 125 | if (info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 126 | { |
| 127 | p.indirect = true; |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 128 | p.sections = b->tensor_shape()[2] * b->tensor_shape()[3]; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 129 | } |
| 130 | else |
| 131 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 132 | p.multis = b->tensor_shape().z(); |
| 133 | p.batches = d->tensor_shape().total_size_upper(2) / p.multis; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 134 | } |
| 135 | |
| 136 | // Update M in case of GEMM3D for output |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 137 | if (info.depth_output_gemm3d != 0) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 138 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 139 | p.M = d->tensor_shape().y() * d->tensor_shape().z(); |
| 140 | p.batches = d->tensor_shape().total_size_upper(3) / p.multis; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 141 | } |
| 142 | |
| 143 | return p; |
| 144 | } |
| 145 | |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 146 | IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataType data_type) |
| 147 | { |
| 148 | // Schedule assembly kernel |
| 149 | const int granule_threshold = 200; |
| 150 | IScheduler::Hints scheduling_hint = IScheduler::Hints(Window::DimX); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 151 | if (method == arm_gemm::GemmMethod::GEMM_INTERLEAVED && data_type == DataType::F32) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 152 | { |
| 153 | scheduling_hint = IScheduler::Hints(Window::DimX, IScheduler::StrategyHint::DYNAMIC, granule_threshold); |
| 154 | } |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 155 | else if (method == arm_gemm::GemmMethod::GEMM_INTERLEAVED_2D && |
| 156 | (data_type == DataType::F32 || data_type == DataType::F16 || data_type == DataType::U8 || |
| 157 | data_type == DataType::S8)) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 158 | { |
| 159 | //GEMM_INTERLEAVED supports 2D parallelism, IScheduler::split_dimensions_all signals to parallelise over all window dimensions |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 160 | scheduling_hint = |
| 161 | IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold); |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 162 | } |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 163 | else if (method == arm_gemm::GemmMethod::QUANTIZE_WRAPPER_2D && |
| 164 | (data_type == DataType::QASYMM8 || data_type == DataType::QASYMM8_SIGNED)) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 165 | { |
| 166 | //special case for QASYMM8 to support 2D parallelism, scheduler here may be tweaked differently compared to FP32 case |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 167 | scheduling_hint = |
| 168 | IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold); |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 169 | } |
| 170 | |
| 171 | return scheduling_hint; |
| 172 | } |
| 173 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 174 | /** Fallback in case ACL doesn't have a function */ |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 175 | template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing> |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 176 | class Fallback : public CpuGemmAssemblyDispatch::IFallback |
Anthony Barbier | c8e84b5 | 2018-07-17 16:48:42 +0100 | [diff] [blame] | 177 | { |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 178 | public: |
Michalis Spyrou | 1a569a3 | 2019-09-10 17:20:34 +0100 | [diff] [blame] | 179 | /** Destructor */ |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 180 | ~Fallback() = default; |
Michalis Spyrou | 1a569a3 | 2019-09-10 17:20:34 +0100 | [diff] [blame] | 181 | |
Georgios Pinitas | 3dbfd23 | 2019-01-30 17:17:16 +0000 | [diff] [blame] | 182 | /** Initialise the functions's input and output. |
| 183 | * |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 184 | * @param[in] a Input tensor containing the Matrix A. |
| 185 | * @param[in] b Input tensor containing the Matrix B. |
| 186 | * @param[in] c Input tensor containing the Matrix C. |
| 187 | * @param[out] d Output tensor to store the result of matrix multiplication. |
| 188 | * @param[in] args Matrix multiplication information. |
| 189 | * @param[in] gemm_info GEMM meta-data |
| 190 | * @param[in] os Output stage meta-data. |
Georgios Pinitas | 3dbfd23 | 2019-01-30 17:17:16 +0000 | [diff] [blame] | 191 | */ |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 192 | void configure(const ITensorInfo *a, |
| 193 | const ITensorInfo *b, |
| 194 | const ITensorInfo *c, |
| 195 | ITensorInfo *d, |
| 196 | arm_gemm::GemmArgs args, |
| 197 | const AsmGemmInfo &gemm_info, |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 198 | const OutputStage &os = {}); |
Georgios Pinitas | 3dbfd23 | 2019-01-30 17:17:16 +0000 | [diff] [blame] | 199 | |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 200 | /** Set requantization shifts to be used |
| 201 | * |
| 202 | * @param[in] shifts Requantization shifts |
| 203 | * |
| 204 | * @return Pointer to the shift data |
| 205 | */ |
| 206 | /** Set requantization data to be used |
| 207 | * |
| 208 | * |
| 209 | * @param shifts Requantization shifts |
| 210 | * @param multipliers Requantization multipliers |
| 211 | * |
| 212 | * @return A tuple with the pointers to the shift and multiplier data respectively |
| 213 | */ |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 214 | std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> |
| 215 | set_requantize_data(const std::vector<int32_t> &shifts, const std::vector<int32_t> &multipliers); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 216 | |
Georgios Pinitas | 3dbfd23 | 2019-01-30 17:17:16 +0000 | [diff] [blame] | 217 | // Inherited methods overridden: |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 218 | void run(ITensorPack &tensors) override; |
| 219 | void prepare(ITensorPack &tensors) override; |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 220 | bool is_configured() const override; |
| 221 | experimental::MemoryRequirements workspace() const override; |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 222 | bool isVarWeightsKernel() const override |
| 223 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 224 | if (!_gemm_kernel_asm) |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 225 | return false; |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 226 | const arm_compute::WeightFormat wf = |
| 227 | assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format); |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 228 | return wf != arm_compute::WeightFormat::UNSPECIFIED && wf != arm_compute::WeightFormat::ANY; |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 229 | } |
Anthony Barbier | c8e84b5 | 2018-07-17 16:48:42 +0100 | [diff] [blame] | 230 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 231 | private: |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 232 | enum AuxTensorIdx |
| 233 | { |
| 234 | AsmGemmWorkspace = 0, |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 235 | PrePretransposedB, /* Transposed B (rhs) before being passed to gemm or pretranspose_B_array */ |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 236 | Pretranspose, |
| 237 | Count |
| 238 | }; |
| 239 | |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 240 | /** Configure the indirect buffer |
| 241 | * |
| 242 | * @param[in] a Input tensor containing the Matrix A. |
| 243 | * @param[in] b Input tensor containing the Matrix B. |
| 244 | * @param[out] d Output tensor to store the result of matrix multiplication. |
| 245 | * @param[in] info GEMM meta-data |
| 246 | */ |
| 247 | void configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info); |
| 248 | /** Prepare the indirect buffer */ |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 249 | void prepare_indirect_buffer(ITensorPack &tensors); |
Anthony Barbier | c8e84b5 | 2018-07-17 16:48:42 +0100 | [diff] [blame] | 250 | |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 251 | /** Operator to transpose B before gemm or pretranspose_B_array*/ |
| 252 | std::unique_ptr<CpuTranspose> _pre_pretranspose_b{nullptr}; |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 253 | /** Assembly Gemm kernel */ |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 254 | std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{nullptr}; |
Michele Di Giorgio | 33f41fa | 2021-03-09 14:09:08 +0000 | [diff] [blame] | 255 | /** Optimised Arm® Neon™ kernel */ |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 256 | std::unique_ptr<INEKernel> _optimised_kernel{nullptr}; |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 257 | /** Assembly GEMM workspace tensor info */ |
| 258 | TensorInfo _workspace_info{}; |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 259 | /** Pre-pre-transposed B tensor info */ |
| 260 | TensorInfo _pre_pretransposed_b_info{}; |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 261 | /** Pre-transpose tensor info */ |
| 262 | TensorInfo _pretranspose_info{}; |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 263 | /** Prepared flag */ |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 264 | bool _is_prepared{false}; |
Georgios Pinitas | 37d080f | 2019-06-21 18:43:12 +0100 | [diff] [blame] | 265 | /** GEMM meta-data */ |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 266 | AsmGemmInfo _gemm_info{}; |
Georgios Pinitas | 77d4252 | 2019-11-05 13:35:47 +0000 | [diff] [blame] | 267 | /** GEMM kernel description */ |
| 268 | arm_gemm::KernelDescription _kernel_info{}; |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 269 | /** Per channel quantization shifts */ |
| 270 | std::vector<int32_t> _shifts{}; |
morgolock | 0bc80da | 2020-08-10 16:44:18 +0100 | [diff] [blame] | 271 | std::vector<int32_t> right_shifts{}; |
| 272 | std::vector<int32_t> left_shifts{}; |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 273 | /** Per channel quantization multipliers */ |
| 274 | std::vector<int32_t> _multipliers{}; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 275 | /** Indirect buffer */ |
| 276 | std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{}; |
| 277 | std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{}; |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 278 | std::vector<TypeInput> _indirect_pad{}; |
| 279 | arm_gemm::ConvolutionParameters _cp{}; |
| 280 | experimental::MemoryRequirements _aux_mem{Count}; |
| 281 | bool _B_pretranspose_required{false}; |
| 282 | bool _is_b_constant{true}; |
| 283 | bool _is_c_constant{true}; |
Gunes Bayir | ef63739 | 2024-02-12 21:32:51 +0000 | [diff] [blame] | 284 | bool _run_pre_pretranspose_b{false}; |
| 285 | bool _B_pre_pretranspose_required{false}; |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 286 | }; |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 287 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 288 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 289 | std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 290 | Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts, |
| 291 | const std::vector<int32_t> &multipliers) |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 292 | { |
morgolock | 0bc80da | 2020-08-10 16:44:18 +0100 | [diff] [blame] | 293 | _multipliers = multipliers; |
| 294 | _shifts = shifts; |
| 295 | bool need_left = false; |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 296 | for (const auto s : _shifts) |
morgolock | 0bc80da | 2020-08-10 16:44:18 +0100 | [diff] [blame] | 297 | { |
| 298 | left_shifts.push_back(std::max(-s, int32_t(0))); |
| 299 | right_shifts.push_back(std::min(-s, int32_t(0))); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 300 | if (s < 0 && !need_left) |
morgolock | 0bc80da | 2020-08-10 16:44:18 +0100 | [diff] [blame] | 301 | { |
| 302 | need_left = true; |
| 303 | } |
| 304 | } |
| 305 | return std::make_tuple(need_left, left_shifts.data(), right_shifts.data(), _multipliers.data()); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 306 | } |
| 307 | |
| 308 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 309 | void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 310 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 311 | auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0); |
| 312 | const TypeInput *A_ptr = reinterpret_cast<TypeInput *>(a->buffer()); |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 313 | const int multis = 1; |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 314 | const int batches = a->info()->tensor_shape().total_size_upper(3); |
| 315 | const size_t stride_A = a->info()->strides_in_bytes().y() / sizeof(TypeInput); |
| 316 | const size_t batch_stride_A = a->info()->strides_in_bytes()[3] / sizeof(TypeInput); |
| 317 | const size_t multi_stride_A = a->info()->strides_in_bytes()[4] / sizeof(TypeInput); |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 318 | |
| 319 | const size_t output_hw = _cp.output_height * _cp.output_width; |
| 320 | const int batch_size = _cp.kernel_height * _cp.kernel_width * output_hw * sizeof(TypeInput); |
| 321 | const size_t batch_stride = batch_size / sizeof(TypeInput); |
| 322 | const int multi_size = batch_size * batches; |
| 323 | const size_t multi_stride = multi_size / sizeof(TypeInput); |
| 324 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 325 | for (int64_t m = 0; m < multis; m++) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 326 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 327 | for (int64_t b = 0; b < batches; b++) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 328 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 329 | for (int64_t output_y = 0; output_y < _cp.output_height; output_y++) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 330 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 331 | for (int64_t output_x = 0; output_x < _cp.output_width; output_x++) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 332 | { |
| 333 | int64_t output_xy = (output_y * _cp.output_width) + output_x; |
| 334 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 335 | for (int64_t kernel_y = 0; kernel_y < _cp.kernel_height; kernel_y++) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 336 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 337 | for (int64_t kernel_x = 0; kernel_x < _cp.kernel_width; kernel_x++) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 338 | { |
| 339 | int64_t input_x = (output_x * _cp.output_stride_w) + kernel_x - _cp.padding_left; |
| 340 | int64_t input_y = (output_y * _cp.output_stride_h) + kernel_y - _cp.padding_top; |
| 341 | int64_t kernel_xy = (kernel_y * _cp.kernel_width) + kernel_x; |
| 342 | int64_t input_xy = (input_y * _cp.input_width) + input_x; |
| 343 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 344 | if (input_x < 0 || input_x >= _cp.input_width || input_y < 0 || input_y >= _cp.input_height) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 345 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 346 | _indirect_buf |
| 347 | .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = |
| 348 | _indirect_pad.data(); |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 349 | } |
| 350 | else |
| 351 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 352 | _indirect_buf |
| 353 | .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 354 | A_ptr + (m * multi_stride_A + b * batch_stride_A + input_xy * stride_A); |
| 355 | } |
| 356 | } |
| 357 | } |
| 358 | } |
| 359 | } |
| 360 | } |
| 361 | } |
| 362 | } |
| 363 | |
| 364 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 365 | void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a, |
| 366 | const ITensorInfo *b, |
| 367 | const ITensorInfo *d, |
| 368 | const AsmGemmInfo &info) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 369 | { |
| 370 | ARM_COMPUTE_ERROR_ON(!(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect)); |
| 371 | |
| 372 | float zeropad = 0.f; |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 373 | if (is_data_type_quantized(a->data_type())) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 374 | { |
| 375 | zeropad = a->quantization_info().uniform().offset; |
| 376 | } |
| 377 | |
| 378 | const int64_t input_width = static_cast<int64_t>(a->tensor_shape()[1]); |
| 379 | const int64_t input_height = static_cast<int64_t>(a->tensor_shape()[2]); |
| 380 | const int64_t input_channels = static_cast<int64_t>(a->tensor_shape()[0]); |
| 381 | const int64_t kernel_width = static_cast<int64_t>(b->tensor_shape()[2]); |
| 382 | const int64_t kernel_height = static_cast<int64_t>(b->tensor_shape()[3]); |
| 383 | const int64_t output_width = static_cast<int64_t>(d->tensor_shape()[1]); |
| 384 | const int64_t output_height = static_cast<int64_t>(d->tensor_shape()[2]); |
| 385 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 386 | _cp = {input_width, |
| 387 | input_height, |
| 388 | input_channels, |
| 389 | kernel_width, |
| 390 | kernel_height, |
| 391 | output_width, |
| 392 | output_height, |
| 393 | info.ps_info.stride().first, |
| 394 | info.ps_info.stride().second, |
| 395 | info.padding_top, |
| 396 | info.padding_left, |
| 397 | zeropad}; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 398 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 399 | if (info.method == AsmConvMethod::Conv) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 400 | { |
| 401 | _gemm_kernel_asm->set_convolution_parameters(_cp); |
| 402 | } |
| 403 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 404 | if (info.method == AsmConvMethod::Indirect) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 405 | { |
| 406 | const unsigned int multis = 1; |
| 407 | const unsigned int batches = a->tensor_shape().total_size_upper(3); |
| 408 | const unsigned int kernel_hw = _cp.kernel_width * _cp.kernel_height; |
| 409 | const unsigned int output_hw = _cp.output_width * _cp.output_height; |
| 410 | |
| 411 | using TypeInputPtr = TypeInput *; |
| 412 | const int batch_size = kernel_hw * output_hw * sizeof(TypeInputPtr); |
| 413 | const size_t batch_stride = batch_size / sizeof(TypeInputPtr); |
| 414 | const int multi_size = batch_size * batches; |
| 415 | const size_t multi_stride = multi_size / sizeof(TypeInputPtr); |
| 416 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 417 | _indirect_buf = std::unique_ptr<const TypeInput *, free_delete>( |
| 418 | reinterpret_cast<const TypeInput **>(malloc(multi_size * multis))); |
| 419 | _indirect_arg = std::unique_ptr<const TypeInput *const *, free_delete>( |
| 420 | reinterpret_cast<const TypeInput *const **>(malloc(sizeof(TypeInput **) * kernel_hw * multis * batches))); |
Sang-Hoon Park | 8d5337e | 2021-01-15 14:36:25 +0000 | [diff] [blame] | 421 | _indirect_pad = std::vector<TypeInput>(_cp.input_channels, TypeInput(zeropad)); |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 422 | |
| 423 | // Set indirect argument |
| 424 | int64_t pos = 0; |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 425 | for (int64_t m = 0; m < multis; m++) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 426 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 427 | for (int64_t b = 0; b < batches; b++) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 428 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 429 | for (int64_t kernel_xy = 0; kernel_xy < kernel_hw; kernel_xy++) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 430 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 431 | (_indirect_arg.get())[pos++] = |
| 432 | _indirect_buf.get() + m * multi_stride + b * batch_stride + kernel_xy * output_hw; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 433 | } |
| 434 | } |
| 435 | } |
| 436 | |
| 437 | _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.get()); |
| 438 | } |
| 439 | } |
| 440 | |
| 441 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 442 | void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a, |
| 443 | const ITensorInfo *b, |
| 444 | const ITensorInfo *c, |
| 445 | ITensorInfo *d, |
| 446 | arm_gemm::GemmArgs args, |
| 447 | const AsmGemmInfo &gemm_info, |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 448 | const OutputStage &os) |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 449 | { |
Giorgio Arena | 63e0beb | 2021-09-24 14:04:27 +0100 | [diff] [blame] | 450 | _is_b_constant = b->are_values_constant(); |
| 451 | _is_c_constant = c ? c->are_values_constant() : true; |
| 452 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 453 | _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput, OutputStage>(args, os); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 454 | if (_gemm_kernel_asm == nullptr) |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 455 | { |
| 456 | //configuration not supported: Leave function unconfigured: |
| 457 | return; |
| 458 | } |
| 459 | |
Francesco.Petrogalli@arm.com | 193cad3 | 2022-03-07 13:39:21 +0000 | [diff] [blame] | 460 | arm_gemm::GemmConfig gemm_cfg = _gemm_kernel_asm->get_config(); |
| 461 | |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 462 | // arm_compute wrapper for the Gemm object (see above) |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 463 | auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>(); |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 464 | ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr); |
Georgios Pinitas | 3dbfd23 | 2019-01-30 17:17:16 +0000 | [diff] [blame] | 465 | acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter); |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 466 | const size_t workspace_size = _gemm_kernel_asm->get_working_size(); |
| 467 | const unsigned int alignment = 4096; |
| 468 | _workspace_info = TensorInfo(TensorShape(workspace_size), 1, DataType::U8); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 469 | _aux_mem[AsmGemmWorkspace] = |
| 470 | MemoryInfo(offset_int_vec(AsmGemmWorkspace), MemoryLifetime::Temporary, workspace_size, alignment); |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 471 | |
| 472 | //if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and |
| 473 | //the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001 |
| 474 | { |
Georgios Pinitas | 5aa1a0b | 2020-07-02 20:02:20 +0100 | [diff] [blame] | 475 | const unsigned int window_size = _gemm_kernel_asm->get_window_size().total_size(); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 476 | if (window_size < static_cast<unsigned int>(args._maxthreads)) |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 477 | { |
Anthony Barbier | c8e84b5 | 2018-07-17 16:48:42 +0100 | [diff] [blame] | 478 | _gemm_kernel_asm->set_nthreads(window_size); |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 479 | } |
| 480 | } |
| 481 | |
| 482 | _optimised_kernel = std::move(acl_gemm_wrapper); |
Georgios Pinitas | 37d080f | 2019-06-21 18:43:12 +0100 | [diff] [blame] | 483 | _gemm_info = gemm_info; |
Gunes Bayir | ef63739 | 2024-02-12 21:32:51 +0000 | [diff] [blame] | 484 | |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 485 | // Check if we need to pre-pretranspose B. Fixed format kernels need no pre-pretranspose. |
Gunes Bayir | ef63739 | 2024-02-12 21:32:51 +0000 | [diff] [blame] | 486 | _B_pre_pretranspose_required = _gemm_info.transpose_b && !isVarWeightsKernel(); |
| 487 | _B_pretranspose_required = _gemm_kernel_asm->B_pretranspose_required(); |
| 488 | |
| 489 | const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose(); |
| 490 | const bool kernel_can_fuse_transpose = _B_pretranspose_required && kernel_supports_transpose; |
| 491 | _run_pre_pretranspose_b = _B_pre_pretranspose_required && !kernel_can_fuse_transpose; |
| 492 | |
| 493 | if (_run_pre_pretranspose_b) |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 494 | { |
| 495 | _pre_pretranspose_b = std::make_unique<CpuTranspose>(); |
| 496 | _pre_pretranspose_b->configure(b, &_pre_pretransposed_b_info); |
| 497 | MemoryLifetime lifetime; |
| 498 | if (_is_b_constant) |
| 499 | { |
Gunes Bayir | ef63739 | 2024-02-12 21:32:51 +0000 | [diff] [blame] | 500 | if (_B_pretranspose_required) |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 501 | { |
| 502 | // PrePretransposedB tensor is only used in prepare(), but is then succeeded by Pretranspose |
| 503 | // So PrePretransposedB can be freed inside prepare() |
| 504 | lifetime = MemoryLifetime::Prepare; |
| 505 | } |
| 506 | else |
| 507 | { |
| 508 | // PrePretransposedB tensor is only used in prepare(), but is the final transformation of B |
| 509 | // So PrePretransposedB needs to persist beyond prepare() |
| 510 | lifetime = MemoryLifetime::Persistent; |
| 511 | } |
| 512 | } |
| 513 | else |
| 514 | { |
| 515 | // PrePretransposedB tensor is always used in run() and doesn't need to persist |
| 516 | lifetime = MemoryLifetime::Temporary; |
| 517 | } |
| 518 | // Forcing 128-byte alignment (required by 32-bit kernels) |
| 519 | const unsigned int alignment = 128; |
| 520 | _aux_mem[PrePretransposedB] = |
| 521 | MemoryInfo(offset_int_vec(PrePretransposedB), lifetime, _pre_pretransposed_b_info.total_size(), alignment); |
| 522 | } |
| 523 | |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 524 | // Check for pre-transposed support |
Gunes Bayir | ef63739 | 2024-02-12 21:32:51 +0000 | [diff] [blame] | 525 | if (_B_pretranspose_required) |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 526 | { |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 527 | // Fixed format kernels need no pretranspose. |
| 528 | ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format( |
| 529 | assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format))); |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 530 | // Forcing 128-byte alignment (required by 32-bit kernels) |
| 531 | const unsigned int alignment = 128; |
| 532 | const size_t B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size(); |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 533 | _pretranspose_info = TensorInfo(TensorShape(B_pretranspose_size), 1, DataType::U8); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 534 | _aux_mem[Pretranspose] = |
| 535 | MemoryInfo(offset_int_vec(Pretranspose), MemoryLifetime::Persistent, B_pretranspose_size, alignment); |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 536 | } |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 537 | |
| 538 | // Handle indirect GEMM convolution |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 539 | if (gemm_info.method == AsmConvMethod::Conv || gemm_info.method == AsmConvMethod::Indirect) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 540 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 541 | configure_indirect(a, b, d, gemm_info); |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 542 | } |
Jonathan Deakin | a668f9f | 2024-01-24 09:15:38 +0000 | [diff] [blame^] | 543 | |
| 544 | if (std::is_same<OutputStage, arm_gemm::DequantizeFloat>::value) |
| 545 | { |
| 546 | // Output dequantization is just the two src scales multiplied together |
| 547 | _gemm_kernel_asm->set_dequantize_scale(a->quantization_info().uniform().scale * |
| 548 | b->quantization_info().uniform().scale); |
| 549 | } |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 550 | } |
| 551 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 552 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 553 | void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors) |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 554 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 555 | if (!_is_prepared) |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 556 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 557 | auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1); |
| 558 | auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2); |
SiCong Li | d4650e9 | 2023-11-14 15:17:10 +0000 | [diff] [blame] | 559 | ARM_COMPUTE_ERROR_ON_NULLPTR(b); |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 560 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 561 | // Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C. |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 562 | if (c && c->info()->data_type() == DataType::S32) |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 563 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 564 | _gemm_kernel_asm->set_quantized_bias( |
| 565 | reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 566 | } |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 567 | const ITensor *b_to_use = b; |
Gunes Bayir | ef63739 | 2024-02-12 21:32:51 +0000 | [diff] [blame] | 568 | |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 569 | // Pre-pretranspose B if required |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 570 | CpuAuxTensorHandler pre_pretransposed_b( |
| 571 | offset_int_vec(PrePretransposedB), _pre_pretransposed_b_info, tensors, |
| 572 | /*pack_inject: no need to inject into tensors*/ |
| 573 | false, |
| 574 | /*bypass_alloc: no need to allocate if pre-pretranspose B is not required as this handle will not be used*/ |
Gunes Bayir | ef63739 | 2024-02-12 21:32:51 +0000 | [diff] [blame] | 575 | !_run_pre_pretranspose_b); |
| 576 | |
| 577 | if (_run_pre_pretranspose_b) |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 578 | { |
| 579 | ARM_COMPUTE_ERROR_ON(_pre_pretranspose_b == nullptr); |
| 580 | ITensorPack pre_pretranspose_pack{{ACL_SRC, b_to_use}, {ACL_DST, pre_pretransposed_b.get()}}; |
| 581 | _pre_pretranspose_b->run(pre_pretranspose_pack); |
| 582 | b_to_use = pre_pretransposed_b.get(); |
| 583 | } |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 584 | |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 585 | // Pretranspose B if required |
Gunes Bayir | ef63739 | 2024-02-12 21:32:51 +0000 | [diff] [blame] | 586 | if (_B_pretranspose_required) |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 587 | { |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 588 | // Fixed format kernels need no pretranspose. |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 589 | ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format( |
| 590 | assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format))); |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 591 | const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size(); |
| 592 | const auto in1_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() + |
| 593 | b_to_use->info()->offset_first_element_in_bytes()); |
| 594 | const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size(); |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 595 | |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 596 | CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false); |
Gunes Bayir | ef63739 | 2024-02-12 21:32:51 +0000 | [diff] [blame] | 597 | |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 598 | ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr); |
Gunes Bayir | ef63739 | 2024-02-12 21:32:51 +0000 | [diff] [blame] | 599 | |
| 600 | const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose(); |
| 601 | run_parallel_pretranspose_B_array<TypeInput, TypeOutput>( |
| 602 | _gemm_kernel_asm.get(), pretranspose.get(), in1_ptr, ldb, multi_stride_b, |
| 603 | NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose); |
Georgios Pinitas | fa1db17 | 2021-08-12 06:28:09 +0100 | [diff] [blame] | 604 | |
| 605 | b->mark_as_unused(); |
Gunes Bayir | ef63739 | 2024-02-12 21:32:51 +0000 | [diff] [blame] | 606 | // Note that we don't need to mark b_to_use as unused, as if it's been assigned to pre_pretransposed_b, |
| 607 | // its memory will be auto-managed by the handler |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 608 | } |
| 609 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 610 | if (_gemm_info.method == AsmConvMethod::Indirect) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 611 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 612 | prepare_indirect_buffer(tensors); |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 613 | } |
| 614 | |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 615 | _is_prepared = true; |
| 616 | } |
| 617 | } |
| 618 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 619 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 620 | bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 621 | { |
| 622 | return _optimised_kernel != nullptr; |
| 623 | } |
| 624 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 625 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 626 | experimental::MemoryRequirements Fallback<TypeInput, TypeOutput, OutputStage>::workspace() const |
| 627 | { |
| 628 | return _aux_mem; |
| 629 | } |
| 630 | |
| 631 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 632 | void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors) |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 633 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 634 | auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0); |
| 635 | auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1); |
| 636 | auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2); |
| 637 | auto d = tensors.get_tensor(TensorType::ACL_DST); |
SiCong Li | d4650e9 | 2023-11-14 15:17:10 +0000 | [diff] [blame] | 638 | ARM_COMPUTE_ERROR_ON_NULLPTR(a, d); |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 639 | |
Jonathan Deakin | a668f9f | 2024-01-24 09:15:38 +0000 | [diff] [blame^] | 640 | // Only update at runtime if the src quantization is dynamic |
| 641 | if (std::is_same<OutputStage, arm_gemm::DequantizeFloat>::value && |
| 642 | (a->info()->quantization_info().is_dynamic() || b->info()->quantization_info().is_dynamic())) |
| 643 | { |
| 644 | // Output dequantization is just the two src scales multiplied together |
| 645 | _gemm_kernel_asm->set_dequantize_scale(a->info()->quantization_info().uniform().scale * |
| 646 | b->info()->quantization_info().uniform().scale); |
| 647 | } |
| 648 | |
Jonathan Deakin | 464ed20 | 2023-01-12 11:41:14 +0000 | [diff] [blame] | 649 | int lda = a->info()->strides_in_bytes().y() / a->info()->element_size(); |
Georgios Pinitas | 40ed6d8 | 2018-07-31 17:22:11 +0100 | [diff] [blame] | 650 | int ldb = 0; |
Jonathan Deakin | 464ed20 | 2023-01-12 11:41:14 +0000 | [diff] [blame] | 651 | const int ldd = d->info()->strides_in_bytes().y() / d->info()->element_size(); |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 652 | |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 653 | const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d != 0 ? 3 : 2; |
Georgios Pinitas | 37d080f | 2019-06-21 18:43:12 +0100 | [diff] [blame] | 654 | const size_t a_multi_idx = a_batch_idx + 1; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 655 | const size_t d_batch_idx = _gemm_info.depth_output_gemm3d != 0 ? 3 : 2; |
Georgios Pinitas | 37d080f | 2019-06-21 18:43:12 +0100 | [diff] [blame] | 656 | const size_t d_multi_idx = d_batch_idx + 1; |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 657 | |
Jonathan Deakin | 464ed20 | 2023-01-12 11:41:14 +0000 | [diff] [blame] | 658 | int batch_stride_a = a->info()->strides_in_bytes()[a_batch_idx] / a->info()->element_size(); |
| 659 | const int batch_stride_d = d->info()->strides_in_bytes()[d_batch_idx] / d->info()->element_size(); |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 660 | |
Jonathan Deakin | 464ed20 | 2023-01-12 11:41:14 +0000 | [diff] [blame] | 661 | int multi_stride_a = a->info()->strides_in_bytes()[a_multi_idx] / a->info()->element_size(); |
Georgios Pinitas | 40ed6d8 | 2018-07-31 17:22:11 +0100 | [diff] [blame] | 662 | int multi_stride_b = 0; |
Jonathan Deakin | 464ed20 | 2023-01-12 11:41:14 +0000 | [diff] [blame] | 663 | const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / d->info()->element_size(); |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 664 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 665 | auto in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes()); |
Georgios Pinitas | 40ed6d8 | 2018-07-31 17:22:11 +0100 | [diff] [blame] | 666 | const TypeInput *in1_ptr = nullptr; |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 667 | auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes()); |
Georgios Pinitas | 40ed6d8 | 2018-07-31 17:22:11 +0100 | [diff] [blame] | 668 | |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 669 | const ITensor *b_to_use = b; |
| 670 | |
| 671 | // Pre-pretranspose B if required |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 672 | CpuAuxTensorHandler pre_pretransposed_b( |
| 673 | offset_int_vec(PrePretransposedB), _pre_pretransposed_b_info, tensors, |
| 674 | false /*pack_inject: no need to inject into tensors*/, |
Gunes Bayir | ef63739 | 2024-02-12 21:32:51 +0000 | [diff] [blame] | 675 | !_run_pre_pretranspose_b /*bypass_alloc: no need to allocate if pre-pretranspose B is not required as this handle will not be used*/); |
| 676 | if (b_to_use && !_is_b_constant && _run_pre_pretranspose_b) |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 677 | { |
| 678 | ARM_COMPUTE_ERROR_ON(_pre_pretranspose_b == nullptr); |
| 679 | ITensorPack pre_pretranspose_pack{{ACL_SRC, b_to_use}, {ACL_DST, pre_pretransposed_b.get()}}; |
| 680 | _pre_pretranspose_b->run(pre_pretranspose_pack); |
| 681 | b_to_use = pre_pretransposed_b.get(); |
| 682 | } |
| 683 | |
Georgios Pinitas | 40ed6d8 | 2018-07-31 17:22:11 +0100 | [diff] [blame] | 684 | // Check if B is pre-tranposed and de-reference if not |
SiCong Li | d4650e9 | 2023-11-14 15:17:10 +0000 | [diff] [blame] | 685 | if (b_to_use && !_gemm_kernel_asm->B_is_pretransposed()) |
Georgios Pinitas | 40ed6d8 | 2018-07-31 17:22:11 +0100 | [diff] [blame] | 686 | { |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 687 | ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size(); |
| 688 | multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size(); |
| 689 | in1_ptr = |
| 690 | reinterpret_cast<const TypeInput *>(b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes()); |
Georgios Pinitas | 40ed6d8 | 2018-07-31 17:22:11 +0100 | [diff] [blame] | 691 | } |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 692 | |
Giorgio Arena | 63e0beb | 2021-09-24 14:04:27 +0100 | [diff] [blame] | 693 | // If necessary, run pretranspose every time if either weights or biases are non-constant |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 694 | if ((b_to_use && !_is_b_constant) || (c && !_is_c_constant && c->info()->data_type() == DataType::S32)) |
Giorgio Arena | 63e0beb | 2021-09-24 14:04:27 +0100 | [diff] [blame] | 695 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 696 | if (c && c->info()->data_type() == DataType::S32) |
Giorgio Arena | 63e0beb | 2021-09-24 14:04:27 +0100 | [diff] [blame] | 697 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 698 | _gemm_kernel_asm->set_quantized_bias( |
| 699 | reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0); |
Giorgio Arena | 63e0beb | 2021-09-24 14:04:27 +0100 | [diff] [blame] | 700 | } |
| 701 | |
| 702 | // Pretranspose B if required |
SiCong Li | d4650e9 | 2023-11-14 15:17:10 +0000 | [diff] [blame] | 703 | if (b_to_use && _B_pretranspose_required) |
Giorgio Arena | 63e0beb | 2021-09-24 14:04:27 +0100 | [diff] [blame] | 704 | { |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 705 | // Fixed format kernels need no pretranspose. |
| 706 | ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format( |
| 707 | assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format))); |
| 708 | const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size(); |
| 709 | const auto b_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() + |
| 710 | b_to_use->info()->offset_first_element_in_bytes()); |
| 711 | const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size(); |
Giorgio Arena | 63e0beb | 2021-09-24 14:04:27 +0100 | [diff] [blame] | 712 | |
| 713 | CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true); |
| 714 | ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr); |
| 715 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 716 | if (_is_b_constant) |
Giorgio Arena | 63e0beb | 2021-09-24 14:04:27 +0100 | [diff] [blame] | 717 | { |
| 718 | _gemm_kernel_asm->requantize_bias(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b); |
| 719 | } |
| 720 | else |
| 721 | { |
Gunes Bayir | ef63739 | 2024-02-12 21:32:51 +0000 | [diff] [blame] | 722 | const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose(); |
| 723 | run_parallel_pretranspose_B_array<TypeInput, TypeOutput>( |
| 724 | _gemm_kernel_asm.get(), pretranspose.get(), b_ptr, ldb, multi_stride_b, |
| 725 | NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose); |
Giorgio Arena | 63e0beb | 2021-09-24 14:04:27 +0100 | [diff] [blame] | 726 | } |
| 727 | } |
| 728 | } |
| 729 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 730 | const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, d->info()->data_type()); |
Joseph Dobson | 6f8b17d | 2020-02-11 19:32:11 +0000 | [diff] [blame] | 731 | |
David Mansell | 9e698d5 | 2020-08-25 15:02:02 +0100 | [diff] [blame] | 732 | // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 733 | CpuAuxTensorHandler workspace(offset_int_vec(AsmGemmWorkspace), _workspace_info, tensors, false); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 734 | if (workspace.get()->buffer() != nullptr) |
David Mansell | 9e698d5 | 2020-08-25 15:02:02 +0100 | [diff] [blame] | 735 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 736 | _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(workspace.get()->buffer())); |
David Mansell | 9e698d5 | 2020-08-25 15:02:02 +0100 | [diff] [blame] | 737 | const unsigned int split_dim = scheduling_hint.split_dimension(); |
| 738 | const unsigned int window_size = _gemm_kernel_asm->get_window_size().total_size(); |
Pablo Marquez Tello | 17e116e | 2023-12-05 15:44:50 +0000 | [diff] [blame] | 739 | unsigned int num_threads = NEScheduler::get().num_threads(); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 740 | if (window_size < num_threads) |
David Mansell | 9e698d5 | 2020-08-25 15:02:02 +0100 | [diff] [blame] | 741 | { |
| 742 | num_threads = window_size; |
| 743 | } |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 744 | if (split_dim != IScheduler::split_dimensions_all) |
David Mansell | 9e698d5 | 2020-08-25 15:02:02 +0100 | [diff] [blame] | 745 | { |
| 746 | // Make sure the kernel does not expect more threads than we can actually spawn |
| 747 | const unsigned int num_iterations = _optimised_kernel.get()->window().num_iterations(split_dim); |
| 748 | num_threads = std::min(num_iterations, num_threads); |
| 749 | } |
| 750 | _gemm_kernel_asm->set_nthreads(num_threads); |
| 751 | } |
| 752 | |
| 753 | // Prepare assembly kernel |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 754 | prepare(tensors); |
David Mansell | 9e698d5 | 2020-08-25 15:02:02 +0100 | [diff] [blame] | 755 | |
David Mansell | 9e698d5 | 2020-08-25 15:02:02 +0100 | [diff] [blame] | 756 | // Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C. |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 757 | TypeOutput *bias = nullptr; |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 758 | if (c && c->info()->data_type() != DataType::S32) |
David Mansell | 9e698d5 | 2020-08-25 15:02:02 +0100 | [diff] [blame] | 759 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 760 | bias = reinterpret_cast<TypeOutput *>(c->buffer() + c->info()->offset_first_element_in_bytes()); |
David Mansell | 9e698d5 | 2020-08-25 15:02:02 +0100 | [diff] [blame] | 761 | } |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 762 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 763 | if (_gemm_info.method == AsmConvMethod::Indirect) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 764 | { |
| 765 | in0_ptr = nullptr; |
| 766 | lda = 0; |
| 767 | batch_stride_a = 0; |
| 768 | multi_stride_a = 0; |
| 769 | } |
| 770 | |
David Mansell | 9e698d5 | 2020-08-25 15:02:02 +0100 | [diff] [blame] | 771 | // Set gemm parameters |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 772 | _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, |
| 773 | ldd, batch_stride_d, multi_stride_d, bias, 0); |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 774 | // Schedule |
Georgios Pinitas | 77d4252 | 2019-11-05 13:35:47 +0000 | [diff] [blame] | 775 | NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint); |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 776 | } |
| 777 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 778 | template <typename TypeInput, typename TypeOutput> |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 779 | void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 780 | const ITensorInfo *a, |
| 781 | const ITensorInfo *b, |
| 782 | const ITensorInfo *c, |
| 783 | ITensorInfo *d, |
| 784 | arm_gemm::Activation activation, |
| 785 | const AsmGemmInfo &info) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 786 | { |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 787 | Params p = extract_parameters(a, b, d, info); |
Pablo Marquez Tello | 17e116e | 2023-12-05 15:44:50 +0000 | [diff] [blame] | 788 | const CPUInfo &ci = NEScheduler::get().cpu_info(); |
| 789 | unsigned int num_threads = NEScheduler::get().num_threads(); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 790 | |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 791 | arm_gemm::GemmConfig cfg; |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 792 | cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 793 | arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, |
Radu Salavat | f1f1f87 | 2024-02-27 18:32:26 +0000 | [diff] [blame] | 794 | info.fixed_format, info.fast_mode, info.accumulate, &cfg); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 795 | |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 796 | // Create arm_gemm fallback |
Georgios Pinitas | 40f51a6 | 2020-11-21 03:04:18 +0000 | [diff] [blame] | 797 | auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>(); |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 798 | fallback->configure(a, b, c, d, args, info); |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 799 | arm_gemm = std::move(fallback); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 800 | } |
| 801 | |
| 802 | template <typename TypeInput, typename TypeOutput> |
Jonathan Deakin | a668f9f | 2024-01-24 09:15:38 +0000 | [diff] [blame^] | 803 | void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, |
| 804 | const ITensorInfo *a, |
| 805 | const ITensorInfo *b, |
| 806 | const ITensorInfo *c, |
| 807 | ITensorInfo *d, |
| 808 | arm_gemm::Activation activation, |
| 809 | const AsmGemmInfo &info) |
| 810 | { |
| 811 | ARM_COMPUTE_UNUSED(activation); |
| 812 | |
| 813 | Params p = extract_parameters(a, b, d, info); |
| 814 | const CPUInfo &ci = NEScheduler::get().cpu_info(); |
| 815 | const unsigned int num_threads = NEScheduler::get().num_threads(); |
| 816 | |
| 817 | arm_gemm::GemmConfig cfg; |
| 818 | cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); |
| 819 | arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, |
| 820 | info.fixed_format, info.fast_mode, info.accumulate, &cfg); |
| 821 | |
| 822 | // Create arm_gemm fallback |
| 823 | auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::DequantizeFloat>>(); |
| 824 | |
| 825 | // Configure requantization info |
| 826 | const GEMMLowpOutputStageInfo os_info = info.output_stage; |
| 827 | |
| 828 | arm_gemm::DequantizeFloat gemm_dequant_info{}; |
| 829 | gemm_dequant_info = arm_gemm::DequantizeFloat(d->quantization_info().uniform().scale); |
| 830 | |
| 831 | fallback->configure(a, b, c, d, args, info, gemm_dequant_info); |
| 832 | arm_gemm = std::move(fallback); |
| 833 | } |
| 834 | |
| 835 | template <typename TypeInput, typename TypeOutput> |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 836 | void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 837 | const ITensorInfo *a, |
| 838 | const ITensorInfo *b, |
| 839 | const ITensorInfo *c, |
| 840 | ITensorInfo *d, |
| 841 | arm_gemm::Activation activation, |
| 842 | const AsmGemmInfo &info) |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 843 | { |
Michele Di Giorgio | 6ad60af | 2020-06-09 14:52:15 +0100 | [diff] [blame] | 844 | ARM_COMPUTE_UNUSED(activation); |
Georgios Pinitas | 4ee8b15 | 2021-07-16 16:16:43 +0100 | [diff] [blame] | 845 | Params p = extract_parameters(a, b, d, info); |
Pablo Marquez Tello | 17e116e | 2023-12-05 15:44:50 +0000 | [diff] [blame] | 846 | const CPUInfo &ci = NEScheduler::get().cpu_info(); |
| 847 | const unsigned int num_threads = NEScheduler::get().num_threads(); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 848 | |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 849 | arm_gemm::GemmConfig cfg; |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 850 | cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 851 | arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, |
Radu Salavat | f1f1f87 | 2024-02-27 18:32:26 +0000 | [diff] [blame] | 852 | info.fixed_format, info.fast_mode, info.accumulate, &cfg); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 853 | |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 854 | // Create arm_gemm fallback |
Georgios Pinitas | 40f51a6 | 2020-11-21 03:04:18 +0000 | [diff] [blame] | 855 | auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>(); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 856 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 857 | // Configure requantization info |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 858 | const int32_t negation = info.negated_offsets ? 1 : -1; |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 859 | const int32_t a_offset = -a->quantization_info().uniform().offset * negation; |
| 860 | const int32_t b_offset = -b->quantization_info().uniform().offset * negation; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 861 | const GEMMLowpOutputStageInfo os_info = info.output_stage; |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 862 | |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 863 | arm_gemm::Requantize32 gemm_requant_info{}; |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 864 | if (os_info.gemmlowp_shifts.size() > 1) |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 865 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 866 | const auto requantize_data = |
| 867 | fallback->set_requantize_data(os_info.gemmlowp_shifts, os_info.gemmlowp_multipliers); |
| 868 | gemm_requant_info = arm_gemm::Requantize32( |
| 869 | nullptr, 0, a_offset, b_offset, os_info.gemmlowp_offset, |
| 870 | (std::get<0>(requantize_data)) ? std::get<1>(requantize_data) : nullptr, std::get<2>(requantize_data), |
| 871 | std::get<3>(requantize_data), os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 872 | } |
| 873 | else |
| 874 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 875 | gemm_requant_info = |
| 876 | arm_gemm::Requantize32(nullptr, 0, a_offset, b_offset, os_info.gemmlowp_offset, -os_info.gemmlowp_shift, |
| 877 | os_info.gemmlowp_multiplier, os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 878 | } |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 879 | |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 880 | // Configure fallback |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 881 | fallback->configure(a, b, c, d, args, info, gemm_requant_info); |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 882 | arm_gemm = std::move(fallback); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 883 | } |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 884 | } //namespace |
| 885 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 886 | CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch() : _arm_gemm(nullptr) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 887 | { |
| 888 | } |
| 889 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 890 | Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, |
| 891 | const ITensorInfo *a, |
| 892 | const ITensorInfo *b, |
| 893 | const ITensorInfo *c, |
| 894 | const ITensorInfo *d, |
| 895 | const AsmGemmInfo &info) |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 896 | { |
| 897 | ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); |
| 898 | ARM_COMPUTE_UNUSED(c); |
| 899 | arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(info.activation_info); |
| 900 | Params p = extract_parameters(a, b, d, info); |
Pablo Marquez Tello | 17e116e | 2023-12-05 15:44:50 +0000 | [diff] [blame] | 901 | const CPUInfo &ci = NEScheduler::get().cpu_info(); |
| 902 | unsigned int num_threads = NEScheduler::get().num_threads(); |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 903 | arm_gemm::GemmConfig cfg; |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 904 | cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); |
| 905 | arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 906 | arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, |
Radu Salavat | f1f1f87 | 2024-02-27 18:32:26 +0000 | [diff] [blame] | 907 | info.fixed_format, info.fast_mode, info.accumulate, &cfg); |
SiCong Li | c5ab4df | 2023-10-17 17:38:57 +0100 | [diff] [blame] | 908 | // TODO: Incorporate info.transpose_b COMPMID-6595 |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 909 | switch (a->data_type()) |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 910 | { |
| 911 | case DataType::F32: |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 912 | ARM_COMPUTE_RETURN_ERROR_ON_MSG( |
| 913 | !(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), |
| 914 | "We could not find an optimized kernel for F32 input"); |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 915 | break; |
| 916 | #ifdef __aarch64__ |
| 917 | case DataType::U8: |
| 918 | case DataType::QASYMM8: |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 919 | if (d->data_type() == DataType::S32) |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 920 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 921 | ARM_COMPUTE_RETURN_ERROR_ON_MSG( |
| 922 | !(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), |
| 923 | "We could not find an optimized kernel for U8/QASYMM8 input and U32 output"); |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 924 | } |
| 925 | else |
| 926 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 927 | ARM_COMPUTE_RETURN_ERROR_ON_MSG( |
| 928 | !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})), |
| 929 | "We could not find an optimized kernel for U8 input and U8 output"); |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 930 | } |
| 931 | break; |
| 932 | case DataType::S8: |
| 933 | case DataType::QASYMM8_SIGNED: |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 934 | if (d->data_type() == DataType::S32) |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 935 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 936 | ARM_COMPUTE_RETURN_ERROR_ON_MSG( |
| 937 | !(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), |
| 938 | "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output"); |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 939 | } |
| 940 | else |
| 941 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 942 | ARM_COMPUTE_RETURN_ERROR_ON_MSG( |
| 943 | !(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})), |
| 944 | "We could not find an optimized kernel for S8 input and S8 output"); |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 945 | } |
| 946 | break; |
| 947 | #endif /* __aarch64__ */ |
Pablo Marquez Tello | d208f4f | 2022-07-19 12:19:46 +0100 | [diff] [blame] | 948 | #if defined(ARM_COMPUTE_ENABLE_BF16) |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 949 | case DataType::BFLOAT16: |
| 950 | { |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame] | 951 | if (d->data_type() == DataType::BFLOAT16) |
| 952 | { |
| 953 | ARM_COMPUTE_RETURN_ERROR_ON_MSG( |
| 954 | !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), |
| 955 | "We could not find an optimized kernel for BFLOAT16 input and BFLOAT16 output"); |
| 956 | } |
| 957 | else |
| 958 | { |
| 959 | ARM_COMPUTE_RETURN_ERROR_ON_MSG( |
| 960 | !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), |
| 961 | "We could not find an optimized kernel for BFLOAT16 input and F32 output"); |
| 962 | } |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 963 | break; |
| 964 | } |
Pablo Marquez Tello | d208f4f | 2022-07-19 12:19:46 +0100 | [diff] [blame] | 965 | #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 966 | #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC |
| 967 | case DataType::F16: |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 968 | ARM_COMPUTE_RETURN_ERROR_ON_MSG( |
| 969 | !(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), |
| 970 | "We could not find an optimized kernel for F16 input and F16 output"); |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 971 | break; |
| 972 | #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ |
| 973 | default: |
| 974 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, "Usupported type. Could not find a kernel"); |
| 975 | break; |
| 976 | } |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 977 | expected_weight_format = assembly_utils::map_to_arm_compute_weight_format(arm_gemm_expected_wf); |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 978 | |
| 979 | return Status{}; |
| 980 | } |
| 981 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 982 | Status CpuGemmAssemblyDispatch::validate( |
| 983 | const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 984 | { |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 985 | ARM_COMPUTE_UNUSED(c, info); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 986 | ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d); |
| 987 | ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a); |
Georgios Pinitas | c7b183a | 2020-03-06 18:12:09 +0000 | [diff] [blame] | 988 | ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 989 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(info.reshape_b_only_on_first_run), |
| 990 | "Assembly kernel will not be executed when reshape_b_only_on_first_run is false"); |
Georgios Pinitas | 0f954eb | 2020-06-23 17:28:38 +0100 | [diff] [blame] | 991 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 992 | #ifndef __aarch64__ |
Michele Di Giorgio | 5255672 | 2019-12-23 16:35:12 +0000 | [diff] [blame] | 993 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->element_size() == 1, "8bit integer types only supported for aarch64"); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 994 | #endif /* __aarch64__ */ |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 995 | ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::U8, DataType::QASYMM8, |
| 996 | DataType::QASYMM8_SIGNED, DataType::S8, DataType::BFLOAT16, |
| 997 | DataType::F16, DataType::F32); |
| 998 | ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN( |
| 999 | b, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::S8, |
| 1000 | DataType::BFLOAT16, DataType::F16, DataType::F32); |
| 1001 | if (is_data_type_quantized_per_channel(b->data_type())) |
Georgios Pinitas | dbdea0d | 2019-10-16 19:21:40 +0100 | [diff] [blame] | 1002 | { |
| 1003 | ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::QASYMM8_SIGNED, DataType::S8); |
| 1004 | } |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 1005 | else if (is_fixed_format_fast_math(info.weight_format)) |
Jonathan Deakin | 464ed20 | 2023-01-12 11:41:14 +0000 | [diff] [blame] | 1006 | { |
| 1007 | ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32); |
| 1008 | ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16); |
| 1009 | } |
Georgios Pinitas | dbdea0d | 2019-10-16 19:21:40 +0100 | [diff] [blame] | 1010 | else |
| 1011 | { |
| 1012 | ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b); |
| 1013 | } |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 1014 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F32 && d->data_type() != DataType::F32, |
| 1015 | "Only F32 output supported for F32 input"); |
| 1016 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16, |
| 1017 | "Only F16 output supported for F16 input"); |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame] | 1018 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 && |
| 1019 | (d->data_type() != DataType::F32 && d->data_type() != DataType::BFLOAT16), |
| 1020 | "Only F32/BFLOAT16 output supported for BFLOAT16 input"); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 1021 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, |
| 1022 | "Only U32 output supported for U8 input"); |
| 1023 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, |
| 1024 | "Only S32 output supported for S8 input"); |
| 1025 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && |
| 1026 | (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32), |
Ethan Doe | 1fe48ca | 2023-03-01 23:19:26 +0000 | [diff] [blame] | 1027 | "Only QASYMM8/S32 output supported for QASYMM8 input"); |
Viet-Hoa Do | 246fe08 | 2023-08-16 10:29:00 +0100 | [diff] [blame] | 1028 | arm_compute::WeightFormat expected_weight_format = arm_compute::WeightFormat::UNSPECIFIED; |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 1029 | const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 1030 | if ((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY) |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 1031 | { |
| 1032 | // Correctness check: if the format expected by the kernel is |
| 1033 | // not "any", make sure that the one found matches the format |
| 1034 | // intended by the caller. |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 1035 | ARM_COMPUTE_RETURN_ERROR_ON_MSG( |
| 1036 | (expected_weight_format != info.weight_format), |
| 1037 | "The format expected by the kernel does not correspond with the one requested by the user."); |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 1038 | } |
| 1039 | return ret; |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 1040 | } |
| 1041 | |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 1042 | bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation) |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 1043 | { |
Michele Di Giorgio | d02d5ed | 2021-01-22 09:47:04 +0000 | [diff] [blame] | 1044 | arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(activation); |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 1045 | return act.type != arm_gemm::Activation::Type::None; |
| 1046 | } |
| 1047 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 1048 | void CpuGemmAssemblyDispatch::configure( |
| 1049 | const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, const AsmGemmInfo &info) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 1050 | { |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 1051 | ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); |
Michele Di Giorgio | d02d5ed | 2021-01-22 09:47:04 +0000 | [diff] [blame] | 1052 | arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(info.activation_info); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 1053 | |
| 1054 | //If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured() |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 1055 | if (!CpuGemmAssemblyDispatch::validate(a, b, c, d, info)) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 1056 | { |
| 1057 | return; |
| 1058 | } |
| 1059 | |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 1060 | switch (a->data_type()) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 1061 | { |
| 1062 | case DataType::F32: |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 1063 | create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 1064 | break; |
| 1065 | #ifdef __aarch64__ |
| 1066 | case DataType::U8: |
| 1067 | case DataType::QASYMM8: |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 1068 | if (d->data_type() == DataType::S32) |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 1069 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 1070 | create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 1071 | } |
| 1072 | else |
| 1073 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 1074 | create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 1075 | } |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 1076 | break; |
| 1077 | case DataType::S8: |
Georgios Pinitas | dbdea0d | 2019-10-16 19:21:40 +0100 | [diff] [blame] | 1078 | case DataType::QASYMM8_SIGNED: |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 1079 | if (d->data_type() == DataType::S32) |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 1080 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 1081 | create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 1082 | } |
Jonathan Deakin | a668f9f | 2024-01-24 09:15:38 +0000 | [diff] [blame^] | 1083 | else if (d->data_type() == DataType::F32) |
| 1084 | { |
| 1085 | create_arm_gemm_dequant<int8_t, float>(_arm_gemm, a, b, c, d, act, info); |
| 1086 | } |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 1087 | else |
| 1088 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 1089 | create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 1090 | } |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 1091 | break; |
| 1092 | #endif /* __aarch64__ */ |
Pablo Marquez Tello | d208f4f | 2022-07-19 12:19:46 +0100 | [diff] [blame] | 1093 | #if defined(ARM_COMPUTE_ENABLE_BF16) |
Georgios Pinitas | c7b183a | 2020-03-06 18:12:09 +0000 | [diff] [blame] | 1094 | case DataType::BFLOAT16: |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame] | 1095 | if (d->data_type() == DataType::BFLOAT16) |
| 1096 | { |
| 1097 | create_arm_gemm<bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info); |
| 1098 | } |
| 1099 | else |
| 1100 | { |
| 1101 | create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info); |
| 1102 | } |
Georgios Pinitas | c7b183a | 2020-03-06 18:12:09 +0000 | [diff] [blame] | 1103 | break; |
Pablo Marquez Tello | d208f4f | 2022-07-19 12:19:46 +0100 | [diff] [blame] | 1104 | #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 1105 | #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC |
| 1106 | case DataType::F16: |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 1107 | create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 1108 | break; |
| 1109 | #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ |
| 1110 | default: |
| 1111 | break; |
| 1112 | } |
| 1113 | } |
| 1114 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 1115 | void CpuGemmAssemblyDispatch::prepare(ITensorPack &tensors) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 1116 | { |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 1117 | ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 1118 | _arm_gemm->prepare(tensors); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 1119 | } |
| 1120 | |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 1121 | bool CpuGemmAssemblyDispatch::is_configured() const |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 1122 | { |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 1123 | return _arm_gemm && _arm_gemm->is_configured(); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 1124 | } |
| 1125 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 1126 | void CpuGemmAssemblyDispatch::run(ITensorPack &tensors) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 1127 | { |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 1128 | ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 1129 | _arm_gemm->run(tensors); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 1130 | } |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 1131 | |
| 1132 | experimental::MemoryRequirements CpuGemmAssemblyDispatch::workspace() const |
| 1133 | { |
| 1134 | ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); |
| 1135 | return _arm_gemm->workspace(); |
| 1136 | } |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 1137 | } // namespace cpu |
| 1138 | } // namespace arm_compute |