Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 1 | /* |
Jonathan Deakin | 464ed20 | 2023-01-12 11:41:14 +0000 | [diff] [blame] | 2 | * Copyright (c) 2018-2023 Arm Limited. |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 3 | * |
| 4 | * SPDX-License-Identifier: MIT |
| 5 | * |
| 6 | * Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | * of this software and associated documentation files (the "Software"), to |
| 8 | * deal in the Software without restriction, including without limitation the |
| 9 | * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or |
| 10 | * sell copies of the Software, and to permit persons to whom the Software is |
| 11 | * furnished to do so, subject to the following conditions: |
| 12 | * |
| 13 | * The above copyright notice and this permission notice shall be included in all |
| 14 | * copies or substantial portions of the Software. |
| 15 | * |
| 16 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 21 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 22 | * SOFTWARE. |
| 23 | */ |
Georgios Pinitas | 7891a73 | 2021-08-20 21:39:25 +0100 | [diff] [blame] | 24 | #include "src/cpu/operators/internal/CpuGemmAssemblyDispatch.h" |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 25 | |
Sang-Hoon Park | 68dd25f | 2020-10-19 16:00:11 +0100 | [diff] [blame] | 26 | #include "arm_compute/runtime/NEON/NEScheduler.h" |
| 27 | #include "src/core/CPP/Validate.h" |
Pablo Marquez Tello | 93581a5 | 2022-07-21 13:55:27 +0100 | [diff] [blame] | 28 | #include "src/core/NEON/kernels/arm_gemm/utils.hpp" |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 29 | #include "src/core/helpers/MemoryHelpers.h" |
Michele Di Giorgio | d02d5ed | 2021-01-22 09:47:04 +0000 | [diff] [blame] | 30 | #include "src/core/utils/AssemblyUtils.h" |
Georgios Pinitas | 7891a73 | 2021-08-20 21:39:25 +0100 | [diff] [blame] | 31 | #include "src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h" |
| 32 | #include "src/cpu/kernels/assembly/arm_gemm.hpp" |
| 33 | #include "src/cpu/utils/CpuAuxTensorHandler.h" |
Michele Di Giorgio | 6ad60af | 2020-06-09 14:52:15 +0100 | [diff] [blame] | 34 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 35 | #include <arm_neon.h> |
| 36 | |
Anthony Barbier | c8e84b5 | 2018-07-17 16:48:42 +0100 | [diff] [blame] | 37 | namespace arm_compute |
| 38 | { |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 39 | namespace cpu |
| 40 | { |
SiCong Li | dba672c | 2023-04-06 16:30:18 +0100 | [diff] [blame^] | 41 | namespace |
| 42 | { |
| 43 | /** Run pretranspose_B_array in parallel (1D static scheduling) |
| 44 | * |
| 45 | * @tparam TypeInput |
| 46 | * @tparam TypeOutput |
| 47 | * |
| 48 | * @param[in] gemm_asm GemmCommon kernel to run |
| 49 | * @param[in] dst Pretransposed B array |
| 50 | * @param[in] src B array to be pretransposed |
| 51 | * @param[in] src_ld Stride in y |
| 52 | * @param[in] src_multi_stride Stride in z ("multi") |
| 53 | * @param[in] num_threads Number of threads to run this method. Must be >= 1 |
| 54 | */ |
| 55 | template <typename TypeInput, typename TypeOutput> |
| 56 | void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutput> *gemm_asm, ITensor *dst, const TypeInput *src, int src_ld, int src_multi_stride, unsigned int num_threads) |
| 57 | { |
| 58 | ARM_COMPUTE_ERROR_ON(gemm_asm == nullptr); |
| 59 | ARM_COMPUTE_ERROR_ON(num_threads == 0); |
| 60 | // The window size is also the total workload size |
| 61 | const unsigned int wsize = gemm_asm->get_B_pretranspose_window_size(); |
| 62 | |
| 63 | std::vector<IScheduler::Workload> workloads(num_threads); |
| 64 | for(unsigned int t = 0; t < num_threads; ++t) |
| 65 | { |
| 66 | workloads[t] = [ = ](const ThreadInfo & info) |
| 67 | { |
| 68 | const unsigned int start = (info.thread_id * wsize) / num_threads; |
| 69 | const unsigned int end = ((info.thread_id + 1) * wsize) / num_threads; |
| 70 | |
| 71 | if(start < end) |
| 72 | { |
| 73 | gemm_asm->pretranspose_B_array_part(dst->buffer(), src, src_ld, src_multi_stride, start, end); |
| 74 | } |
| 75 | }; |
| 76 | } |
| 77 | NEScheduler::get().run_tagged_workloads(workloads, "CpuGemmAssemblyDispatch/pretranspose_B_array"); |
| 78 | } |
| 79 | } // namespace |
| 80 | |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 81 | using namespace arm_compute::experimental; |
| 82 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 83 | namespace |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 84 | { |
/** Deleter functor that releases malloc()-obtained memory; for use as the
 *  deleter type of a std::unique_ptr owning a C-allocated buffer. */
struct free_delete
{
    void operator()(void *ptr)
    {
        free(ptr);
    }
};
| 92 | |
/** Shape/meta parameters describing a GEMM workload for the assembly backend.
 *  Filled in by extract_parameters() from the a/b/d tensor infos. */
struct Params
{
    unsigned int M;        /**< Rows of the output (folded with depth for GEMM3D) */
    unsigned int N;        /**< Columns of the output */
    unsigned int K;        /**< Inner (accumulation) dimension, i.e. columns of A */
    unsigned int batches;  /**< Number of batches */
    unsigned int multis;   /**< Number of independent GEMMs ("multi" dimension) */
    unsigned int sections; /**< Kernel spatial size (kh*kw) for convolution, 1 otherwise */
    bool         indirect; /**< True when the indirect/convolution path is used */
};
| 103 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 104 | Params extract_parameters(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 105 | { |
| 106 | ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 107 | Params p; |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 108 | p.M = d->tensor_shape().y(); |
| 109 | p.K = a->tensor_shape().x(); |
| 110 | p.N = d->tensor_shape().x(); |
Georgios Pinitas | 4c634e0 | 2020-12-01 02:17:19 +0000 | [diff] [blame] | 111 | p.batches = 1; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 112 | p.multis = 1; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 113 | p.sections = 1; |
Georgios Pinitas | 4c634e0 | 2020-12-01 02:17:19 +0000 | [diff] [blame] | 114 | p.indirect = false; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 115 | |
| 116 | if(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect) |
| 117 | { |
| 118 | p.indirect = true; |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 119 | p.sections = b->tensor_shape()[2] * b->tensor_shape()[3]; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 120 | } |
| 121 | else |
| 122 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 123 | p.multis = b->tensor_shape().z(); |
| 124 | p.batches = d->tensor_shape().total_size_upper(2) / p.multis; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 125 | } |
| 126 | |
| 127 | // Update M in case of GEMM3D for output |
| 128 | if(info.depth_output_gemm3d != 0) |
| 129 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 130 | p.M = d->tensor_shape().y() * d->tensor_shape().z(); |
| 131 | p.batches = d->tensor_shape().total_size_upper(3) / p.multis; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 132 | } |
| 133 | |
| 134 | return p; |
| 135 | } |
| 136 | |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 137 | IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataType data_type) |
| 138 | { |
| 139 | // Schedule assembly kernel |
| 140 | const int granule_threshold = 200; |
| 141 | IScheduler::Hints scheduling_hint = IScheduler::Hints(Window::DimX); |
| 142 | if(method == arm_gemm::GemmMethod::GEMM_INTERLEAVED && data_type == DataType::F32) |
| 143 | { |
| 144 | scheduling_hint = IScheduler::Hints(Window::DimX, IScheduler::StrategyHint::DYNAMIC, granule_threshold); |
| 145 | } |
| 146 | else if(method == arm_gemm::GemmMethod::GEMM_INTERLEAVED_2D && (data_type == DataType::F32 || data_type == DataType::F16 || data_type == DataType::U8 || data_type == DataType::S8)) |
| 147 | { |
| 148 | //GEMM_INTERLEAVED supports 2D parallelism, IScheduler::split_dimensions_all signals to parallelise over all window dimensions |
| 149 | scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold); |
| 150 | } |
| 151 | else if(method == arm_gemm::GemmMethod::QUANTIZE_WRAPPER_2D && (data_type == DataType::QASYMM8 || data_type == DataType::QASYMM8_SIGNED)) |
| 152 | { |
| 153 | //special case for QASYMM8 to support 2D parallelism, scheduler here may be tweaked differently compared to FP32 case |
| 154 | scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold); |
| 155 | } |
| 156 | |
| 157 | return scheduling_hint; |
| 158 | } |
| 159 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 160 | /** Fallback in case ACL doesn't have a function */ |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 161 | template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing> |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 162 | class Fallback : public CpuGemmAssemblyDispatch::IFallback |
Anthony Barbier | c8e84b5 | 2018-07-17 16:48:42 +0100 | [diff] [blame] | 163 | { |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 164 | public: |
Michalis Spyrou | 1a569a3 | 2019-09-10 17:20:34 +0100 | [diff] [blame] | 165 | /** Destructor */ |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 166 | ~Fallback() = default; |
Michalis Spyrou | 1a569a3 | 2019-09-10 17:20:34 +0100 | [diff] [blame] | 167 | |
Georgios Pinitas | 3dbfd23 | 2019-01-30 17:17:16 +0000 | [diff] [blame] | 168 | /** Initialise the functions's input and output. |
| 169 | * |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 170 | * @param[in] a Input tensor containing the Matrix A. |
| 171 | * @param[in] b Input tensor containing the Matrix B. |
| 172 | * @param[in] c Input tensor containing the Matrix C. |
| 173 | * @param[out] d Output tensor to store the result of matrix multiplication. |
| 174 | * @param[in] args Matrix multiplication information. |
| 175 | * @param[in] gemm_info GEMM meta-data |
| 176 | * @param[in] os Output stage meta-data. |
Georgios Pinitas | 3dbfd23 | 2019-01-30 17:17:16 +0000 | [diff] [blame] | 177 | */ |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 178 | void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 179 | arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info, |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 180 | const OutputStage &os = {}); |
Georgios Pinitas | 3dbfd23 | 2019-01-30 17:17:16 +0000 | [diff] [blame] | 181 | |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 182 | /** Set requantization shifts to be used |
| 183 | * |
| 184 | * @param[in] shifts Requantization shifts |
| 185 | * |
| 186 | * @return Pointer to the shift data |
| 187 | */ |
| 188 | /** Set requantization data to be used |
| 189 | * |
| 190 | * |
| 191 | * @param shifts Requantization shifts |
| 192 | * @param multipliers Requantization multipliers |
| 193 | * |
| 194 | * @return A tuple with the pointers to the shift and multiplier data respectively |
| 195 | */ |
morgolock | 0bc80da | 2020-08-10 16:44:18 +0100 | [diff] [blame] | 196 | std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> set_requantize_data(const std::vector<int32_t> &shifts, |
| 197 | const std::vector<int32_t> &multipliers); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 198 | |
Georgios Pinitas | 3dbfd23 | 2019-01-30 17:17:16 +0000 | [diff] [blame] | 199 | // Inherited methods overridden: |
Mohammed Suhail Munshi | 4b5f6ef | 2022-10-21 11:15:54 +0100 | [diff] [blame] | 200 | void run(ITensorPack &tensors) override; |
| 201 | void prepare(ITensorPack &tensors) override; |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 202 | bool is_configured() const override; |
| 203 | experimental::MemoryRequirements workspace() const override; |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 204 | bool isVarWeightsKernel() const override |
| 205 | { |
| 206 | if(!_gemm_kernel_asm) |
| 207 | return false; |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 208 | const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format); |
| 209 | return wf != arm_compute::WeightFormat::UNSPECIFIED && wf != arm_compute::WeightFormat::ANY; |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 210 | } |
Anthony Barbier | c8e84b5 | 2018-07-17 16:48:42 +0100 | [diff] [blame] | 211 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 212 | private: |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 213 | enum AuxTensorIdx |
| 214 | { |
| 215 | AsmGemmWorkspace = 0, |
| 216 | Pretranspose, |
| 217 | Count |
| 218 | }; |
| 219 | |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 220 | /** Configure the indirect buffer |
| 221 | * |
| 222 | * @param[in] a Input tensor containing the Matrix A. |
| 223 | * @param[in] b Input tensor containing the Matrix B. |
| 224 | * @param[out] d Output tensor to store the result of matrix multiplication. |
| 225 | * @param[in] info GEMM meta-data |
| 226 | */ |
| 227 | void configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info); |
| 228 | /** Prepare the indirect buffer */ |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 229 | void prepare_indirect_buffer(ITensorPack &tensors); |
Anthony Barbier | c8e84b5 | 2018-07-17 16:48:42 +0100 | [diff] [blame] | 230 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 231 | /** Assembly Gemm kernel */ |
Michalis Spyrou | 1a569a3 | 2019-09-10 17:20:34 +0100 | [diff] [blame] | 232 | std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr }; |
Michele Di Giorgio | 33f41fa | 2021-03-09 14:09:08 +0000 | [diff] [blame] | 233 | /** Optimised Arm® Neon™ kernel */ |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 234 | std::unique_ptr<INEKernel> _optimised_kernel{ nullptr }; |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 235 | /** Assembly GEMM workspace tensor info */ |
| 236 | TensorInfo _workspace_info{}; |
| 237 | /** Pre-transpose tensor info */ |
| 238 | TensorInfo _pretranspose_info{}; |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 239 | /** Prepared flag */ |
| 240 | bool _is_prepared{ false }; |
Georgios Pinitas | 37d080f | 2019-06-21 18:43:12 +0100 | [diff] [blame] | 241 | /** GEMM meta-data */ |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 242 | AsmGemmInfo _gemm_info{}; |
Georgios Pinitas | 77d4252 | 2019-11-05 13:35:47 +0000 | [diff] [blame] | 243 | /** GEMM kernel description */ |
| 244 | arm_gemm::KernelDescription _kernel_info{}; |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 245 | /** Per channel quantization shifts */ |
| 246 | std::vector<int32_t> _shifts{}; |
morgolock | 0bc80da | 2020-08-10 16:44:18 +0100 | [diff] [blame] | 247 | std::vector<int32_t> right_shifts{}; |
| 248 | std::vector<int32_t> left_shifts{}; |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 249 | /** Per channel quantization multipliers */ |
| 250 | std::vector<int32_t> _multipliers{}; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 251 | /** Indirect buffer */ |
| 252 | std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{}; |
| 253 | std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{}; |
Mohammed Suhail Munshi | 4b5f6ef | 2022-10-21 11:15:54 +0100 | [diff] [blame] | 254 | std::vector<TypeInput> _indirect_pad{}; |
| 255 | arm_gemm::ConvolutionParameters _cp{}; |
| 256 | experimental::MemoryRequirements _aux_mem{ Count }; |
| 257 | bool _B_pretranspose_required{ false }; |
| 258 | bool _is_b_constant{ true }; |
| 259 | bool _is_c_constant{ true }; |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 260 | }; |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 261 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 262 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 263 | std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> |
| 264 | Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts, const std::vector<int32_t> &multipliers) |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 265 | { |
morgolock | 0bc80da | 2020-08-10 16:44:18 +0100 | [diff] [blame] | 266 | _multipliers = multipliers; |
| 267 | _shifts = shifts; |
| 268 | bool need_left = false; |
| 269 | for(const auto s : _shifts) |
| 270 | { |
| 271 | left_shifts.push_back(std::max(-s, int32_t(0))); |
| 272 | right_shifts.push_back(std::min(-s, int32_t(0))); |
morgolock | fa269bb | 2020-09-08 16:00:56 +0100 | [diff] [blame] | 273 | if(s < 0 && !need_left) |
morgolock | 0bc80da | 2020-08-10 16:44:18 +0100 | [diff] [blame] | 274 | { |
| 275 | need_left = true; |
| 276 | } |
| 277 | } |
| 278 | return std::make_tuple(need_left, left_shifts.data(), right_shifts.data(), _multipliers.data()); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 279 | } |
| 280 | |
/** Fill the indirect buffer with, for every (multi, batch, kernel point, output point),
 *  a pointer to the corresponding input row of A — or to the zero-padding row when the
 *  sampled input coordinate falls outside the input plane.
 *  Assumes configure_indirect() has already allocated _indirect_buf and set up _cp. */
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors)
{
    auto             a     = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    const TypeInput *A_ptr = reinterpret_cast<TypeInput *>(a->buffer());
    const int        multis         = 1;
    const int        batches        = a->info()->tensor_shape().total_size_upper(3);
    // Element (not byte) strides of A in y, batch and multi dimensions
    const size_t     stride_A       = a->info()->strides_in_bytes().y() / sizeof(TypeInput);
    const size_t     batch_stride_A = a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
    const size_t     multi_stride_A = a->info()->strides_in_bytes()[4] / sizeof(TypeInput);

    // Strides within the indirect pointer table (in elements)
    const size_t output_hw    = _cp.output_height * _cp.output_width;
    const int    batch_size   = _cp.kernel_height * _cp.kernel_width * output_hw * sizeof(TypeInput);
    const size_t batch_stride = batch_size / sizeof(TypeInput);
    const int    multi_size   = batch_size * batches;
    const size_t multi_stride = multi_size / sizeof(TypeInput);

    for(int64_t m = 0; m < multis; m++)
    {
        for(int64_t b = 0; b < batches; b++)
        {
            for(int64_t output_y = 0; output_y < _cp.output_height; output_y++)
            {
                for(int64_t output_x = 0; output_x < _cp.output_width; output_x++)
                {
                    int64_t output_xy = (output_y * _cp.output_width) + output_x;

                    for(int64_t kernel_y = 0; kernel_y < _cp.kernel_height; kernel_y++)
                    {
                        for(int64_t kernel_x = 0; kernel_x < _cp.kernel_width; kernel_x++)
                        {
                            // Input coordinate sampled by this (output, kernel) pair
                            int64_t input_x   = (output_x * _cp.output_stride_w) + kernel_x - _cp.padding_left;
                            int64_t input_y   = (output_y * _cp.output_stride_h) + kernel_y - _cp.padding_top;
                            int64_t kernel_xy = (kernel_y * _cp.kernel_width) + kernel_x;
                            int64_t input_xy  = (input_y * _cp.input_width) + input_x;

                            if(input_x < 0 || input_x >= _cp.input_width || input_y < 0 || input_y >= _cp.input_height)
                            {
                                // Out-of-bounds sample: point at the shared zero-padding row
                                _indirect_buf.get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = _indirect_pad.data();
                            }
                            else
                            {
                                _indirect_buf.get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
                                    A_ptr + (m * multi_stride_A + b * batch_stride_A + input_xy * stride_A);
                            }
                        }
                    }
                }
            }
        }
    }
}
| 333 | |
| 334 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
| 335 | void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info) |
| 336 | { |
| 337 | ARM_COMPUTE_ERROR_ON(!(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect)); |
| 338 | |
| 339 | float zeropad = 0.f; |
| 340 | if(is_data_type_quantized(a->data_type())) |
| 341 | { |
| 342 | zeropad = a->quantization_info().uniform().offset; |
| 343 | } |
| 344 | |
| 345 | const int64_t input_width = static_cast<int64_t>(a->tensor_shape()[1]); |
| 346 | const int64_t input_height = static_cast<int64_t>(a->tensor_shape()[2]); |
| 347 | const int64_t input_channels = static_cast<int64_t>(a->tensor_shape()[0]); |
| 348 | const int64_t kernel_width = static_cast<int64_t>(b->tensor_shape()[2]); |
| 349 | const int64_t kernel_height = static_cast<int64_t>(b->tensor_shape()[3]); |
| 350 | const int64_t output_width = static_cast<int64_t>(d->tensor_shape()[1]); |
| 351 | const int64_t output_height = static_cast<int64_t>(d->tensor_shape()[2]); |
| 352 | |
| 353 | _cp = { input_width, input_height, input_channels, kernel_width, kernel_height, output_width, output_height, |
| 354 | info.ps_info.stride().first, info.ps_info.stride().second, info.padding_top, info.padding_left, zeropad |
| 355 | }; |
| 356 | |
| 357 | if(info.method == AsmConvMethod::Conv) |
| 358 | { |
| 359 | _gemm_kernel_asm->set_convolution_parameters(_cp); |
| 360 | } |
| 361 | |
| 362 | if(info.method == AsmConvMethod::Indirect) |
| 363 | { |
| 364 | const unsigned int multis = 1; |
| 365 | const unsigned int batches = a->tensor_shape().total_size_upper(3); |
| 366 | const unsigned int kernel_hw = _cp.kernel_width * _cp.kernel_height; |
| 367 | const unsigned int output_hw = _cp.output_width * _cp.output_height; |
| 368 | |
| 369 | using TypeInputPtr = TypeInput *; |
| 370 | const int batch_size = kernel_hw * output_hw * sizeof(TypeInputPtr); |
| 371 | const size_t batch_stride = batch_size / sizeof(TypeInputPtr); |
| 372 | const int multi_size = batch_size * batches; |
| 373 | const size_t multi_stride = multi_size / sizeof(TypeInputPtr); |
| 374 | |
| 375 | _indirect_buf = std::unique_ptr<const TypeInput *, free_delete>(reinterpret_cast<const TypeInput **>(malloc(multi_size * multis))); |
| 376 | _indirect_arg = std::unique_ptr<const TypeInput *const *, free_delete>(reinterpret_cast<const TypeInput *const **>(malloc(sizeof(TypeInput **) * kernel_hw * multis * batches))); |
Sang-Hoon Park | 8d5337e | 2021-01-15 14:36:25 +0000 | [diff] [blame] | 377 | _indirect_pad = std::vector<TypeInput>(_cp.input_channels, TypeInput(zeropad)); |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 378 | |
| 379 | // Set indirect argument |
| 380 | int64_t pos = 0; |
| 381 | for(int64_t m = 0; m < multis; m++) |
| 382 | { |
| 383 | for(int64_t b = 0; b < batches; b++) |
| 384 | { |
| 385 | for(int64_t kernel_xy = 0; kernel_xy < kernel_hw; kernel_xy++) |
| 386 | { |
| 387 | (_indirect_arg.get())[pos++] = _indirect_buf.get() + m * multi_stride + b * batch_stride + kernel_xy * output_hw; |
| 388 | } |
| 389 | } |
| 390 | } |
| 391 | |
| 392 | _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.get()); |
| 393 | } |
| 394 | } |
| 395 | |
/** Create the arm_gemm kernel for the given arguments, wrap it in an ACL kernel,
 *  and record the workspace/pretranspose memory requirements.
 *  Leaves the object unconfigured when arm_gemm has no suitable kernel.
 *
 * @param[in]  a         Input tensor info for Matrix A.
 * @param[in]  b         Input tensor info for Matrix B.
 * @param[in]  c         Input tensor info for Matrix C (bias); may be nullptr.
 * @param[out] d         Output tensor info.
 * @param[in]  args      arm_gemm kernel arguments.
 * @param[in]  gemm_info GEMM meta-data.
 * @param[in]  os        Output stage meta-data.
 */
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
                                                             arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
                                                             const OutputStage &os)
{
    ARM_COMPUTE_UNUSED(c);

    _is_b_constant = b->are_values_constant();
    _is_c_constant = c ? c->are_values_constant() : true;

    _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput, OutputStage>(args, os);
    if(_gemm_kernel_asm == nullptr)
    {
        //configuration not supported: Leave function unconfigured:
        return;
    }

    arm_gemm::GemmConfig gemm_cfg = _gemm_kernel_asm->get_config();

    // arm_compute wrapper for the Gemm object (see above)
    auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>();
    ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
    acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
    // Record workspace requirement for the assembly kernel (4k-aligned, temporary)
    const size_t       workspace_size = _gemm_kernel_asm->get_working_size();
    const unsigned int alignment      = 4096;
    _workspace_info            = TensorInfo(TensorShape(workspace_size), 1, DataType::U8);
    _aux_mem[AsmGemmWorkspace] = MemoryInfo(offset_int_vec(AsmGemmWorkspace), MemoryLifetime::Temporary, workspace_size, alignment);

    //if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and
    //the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
    {
        // Never ask for more threads than there are units of work
        const unsigned int window_size = _gemm_kernel_asm->get_window_size().total_size();
        if(window_size < static_cast<unsigned int>(args._maxthreads))
        {
            _gemm_kernel_asm->set_nthreads(window_size);
        }
    }

    _optimised_kernel = std::move(acl_gemm_wrapper);
    _gemm_info        = gemm_info;
    // Check for pre-transposed support
    if(_gemm_kernel_asm->B_pretranspose_required())
    {
        // Forcing 128-byte alignment (required by 32-bit kernels)
        const unsigned int alignment           = 128;
        const size_t       B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
        _pretranspose_info       = TensorInfo(TensorShape(B_pretranspose_size), 1, DataType::U8);
        _aux_mem[Pretranspose]   = MemoryInfo(offset_int_vec(Pretranspose), MemoryLifetime::Persistent, B_pretranspose_size, alignment);
        _B_pretranspose_required = true;
    }

    // Handle indirect GEMM convolution
    if(gemm_info.method == AsmConvMethod::Conv || gemm_info.method == AsmConvMethod::Indirect)
    {
        configure_indirect(a, b, d, gemm_info);
    }
}
| 453 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 454 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 455 | void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors) |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 456 | { |
| 457 | if(!_is_prepared) |
| 458 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 459 | auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1); |
| 460 | auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2); |
| 461 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 462 | // Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C. |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 463 | if(c && c->info()->data_type() == DataType::S32) |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 464 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 465 | _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 466 | } |
| 467 | |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 468 | // Pretranspose B if required |
| 469 | if(_gemm_kernel_asm->B_pretranspose_required()) |
| 470 | { |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 471 | // Fixed format kernels need no pretranspose. |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 472 | ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format))); |
Jonathan Deakin | 464ed20 | 2023-01-12 11:41:14 +0000 | [diff] [blame] | 473 | const int ldb = b->info()->strides_in_bytes().y() / b->info()->element_size(); |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 474 | const auto in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes()); |
Jonathan Deakin | 464ed20 | 2023-01-12 11:41:14 +0000 | [diff] [blame] | 475 | const int multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size(); |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 476 | |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 477 | CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false); |
| 478 | ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr); |
SiCong Li | dba672c | 2023-04-06 16:30:18 +0100 | [diff] [blame^] | 479 | run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(), in1_ptr, ldb, multi_stride_b, NEScheduler::get().num_threads()); |
Georgios Pinitas | fa1db17 | 2021-08-12 06:28:09 +0100 | [diff] [blame] | 480 | |
| 481 | b->mark_as_unused(); |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 482 | } |
| 483 | |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 484 | if(_gemm_info.method == AsmConvMethod::Indirect) |
| 485 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 486 | prepare_indirect_buffer(tensors); |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 487 | } |
| 488 | |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 489 | _is_prepared = true; |
| 490 | } |
| 491 | } |
| 492 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 493 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 494 | bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 495 | { |
| 496 | return _optimised_kernel != nullptr; |
| 497 | } |
| 498 | |
template <typename TypeInput, typename TypeOutput, class OutputStage>
experimental::MemoryRequirements Fallback<TypeInput, TypeOutput, OutputStage>::workspace() const
{
    // Auxiliary memory requirements (assembly workspace, pretranspose buffer)
    // collected during configuration; the caller provisions these tensors.
    return _aux_mem;
}
| 504 | |
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
{
    // Executes the assembly GEMM: computes element strides for all operands,
    // re-runs pretranspose/bias setup when B or C are non-constant, binds the
    // workspace, then points the assembly kernel at the buffers and schedules it.
    auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
    auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
    auto d = tensors.get_tensor(TensorType::ACL_DST);

    // Leading dimensions are expressed in elements (bytes / element_size).
    int lda = a->info()->strides_in_bytes().y() / a->info()->element_size();
    int ldb = 0;
    const int ldd = d->info()->strides_in_bytes().y() / d->info()->element_size();

    // When a tensor is reinterpreted as 3D, the batch dimension moves up by one.
    const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d != 0 ? 3 : 2;
    const size_t a_multi_idx = a_batch_idx + 1;
    const size_t d_batch_idx = _gemm_info.depth_output_gemm3d != 0 ? 3 : 2;
    const size_t d_multi_idx = d_batch_idx + 1;

    int batch_stride_a = a->info()->strides_in_bytes()[a_batch_idx] / a->info()->element_size();
    const int batch_stride_d = d->info()->strides_in_bytes()[d_batch_idx] / d->info()->element_size();

    int multi_stride_a = a->info()->strides_in_bytes()[a_multi_idx] / a->info()->element_size();
    int multi_stride_b = 0;
    const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / d->info()->element_size();

    auto in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes());
    const TypeInput *in1_ptr = nullptr; // stays nullptr when B was pretransposed in prepare()
    auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());

    // Check if B is pre-tranposed and de-reference if not
    if(!_gemm_kernel_asm->B_is_pretransposed())
    {
        ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
        multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
        in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
    }

    // If necessary, run pretranspose every time if either weights or biases are non-constant
    if((b && !_is_b_constant) || (c && !_is_c_constant && c->info()->data_type() == DataType::S32))
    {
        if(c && c->info()->data_type() == DataType::S32)
        {
            // Refresh the quantized-bias pointer for the updated (non-constant) bias.
            _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
        }

        // Pretranspose B if required
        if(_B_pretranspose_required)
        {
            // These shadow the outer ldb/multi_stride_b on purpose: they describe
            // the untransposed source B, used only for the pretranspose pass.
            const int ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
            const auto b_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
            const int multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();

            CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
            ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);

            if(_is_b_constant)
            {
                // B itself is constant; only the bias part needs requantizing.
                _gemm_kernel_asm->requantize_bias(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
            }
            else
            {
                // Non-constant B: redo the full pretranspose on every run.
                run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(), b_ptr, ldb, multi_stride_b, NEScheduler::get().num_threads());
            }
        }
    }

    const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, d->info()->data_type());

    // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
    CpuAuxTensorHandler workspace(offset_int_vec(AsmGemmWorkspace), _workspace_info, tensors, false);
    if(workspace.get()->buffer() != nullptr)
    {
        _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(workspace.get()->buffer()));
        const unsigned int split_dim = scheduling_hint.split_dimension();
        const unsigned int window_size = _gemm_kernel_asm->get_window_size().total_size();
        unsigned int num_threads = NEScheduler::get().num_threads();
        if(window_size < num_threads)
        {
            // Never claim more threads than there are units of work.
            num_threads = window_size;
        }
        if(split_dim != IScheduler::split_dimensions_all)
        {
            // Make sure the kernel does not expect more threads than we can actually spawn
            const unsigned int num_iterations = _optimised_kernel.get()->window().num_iterations(split_dim);
            num_threads = std::min(num_iterations, num_threads);
        }
        _gemm_kernel_asm->set_nthreads(num_threads);
    }

    // Prepare assembly kernel
    prepare(tensors);

    // Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C.
    TypeOutput *bias = nullptr;
    if(c && c->info()->data_type() != DataType::S32)
    {
        bias = reinterpret_cast<TypeOutput *>(c->buffer() + c->info()->offset_first_element_in_bytes());
    }

    if(_gemm_info.method == AsmConvMethod::Indirect)
    {
        // Indirect mode: A is reached through the indirect buffer prepared earlier,
        // so the direct A pointer and its strides are cleared.
        in0_ptr = nullptr;
        lda = 0;
        batch_stride_a = 0;
        multi_stride_a = 0;
    }

    // Set gemm parameters
    _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a,
                                 in1_ptr, ldb, multi_stride_b,
                                 out_ptr, ldd, batch_stride_d, multi_stride_d,
                                 bias, 0);
    // Schedule
    NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint);
}
| 619 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 620 | template <typename TypeInput, typename TypeOutput> |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 621 | void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, |
| 622 | const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, |
| 623 | arm_gemm::Activation activation, const AsmGemmInfo &info) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 624 | { |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 625 | Params p = extract_parameters(a, b, d, info); |
| 626 | const CPUInfo &ci = NEScheduler::get().cpu_info(); |
| 627 | unsigned int num_threads = NEScheduler::get().num_threads(); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 628 | |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 629 | arm_gemm::GemmConfig cfg; |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 630 | cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 631 | arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 632 | |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 633 | // Create arm_gemm fallback |
Georgios Pinitas | 40f51a6 | 2020-11-21 03:04:18 +0000 | [diff] [blame] | 634 | auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>(); |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 635 | fallback->configure(a, b, c, d, args, info); |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 636 | arm_gemm = std::move(fallback); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 637 | } |
| 638 | |
| 639 | template <typename TypeInput, typename TypeOutput> |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 640 | void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, |
| 641 | const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, |
| 642 | arm_gemm::Activation activation, const AsmGemmInfo &info) |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 643 | { |
Michele Di Giorgio | 6ad60af | 2020-06-09 14:52:15 +0100 | [diff] [blame] | 644 | ARM_COMPUTE_UNUSED(activation); |
Georgios Pinitas | 4ee8b15 | 2021-07-16 16:16:43 +0100 | [diff] [blame] | 645 | Params p = extract_parameters(a, b, d, info); |
| 646 | const CPUInfo &ci = NEScheduler::get().cpu_info(); |
| 647 | const unsigned int num_threads = NEScheduler::get().num_threads(); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 648 | |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 649 | arm_gemm::GemmConfig cfg; |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 650 | cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 651 | arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 652 | |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 653 | // Create arm_gemm fallback |
Georgios Pinitas | 40f51a6 | 2020-11-21 03:04:18 +0000 | [diff] [blame] | 654 | auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>(); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 655 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 656 | // Configure requantization info |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 657 | const int32_t negation = info.negated_offsets ? 1 : -1; |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 658 | const int32_t a_offset = -a->quantization_info().uniform().offset * negation; |
| 659 | const int32_t b_offset = -b->quantization_info().uniform().offset * negation; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 660 | const GEMMLowpOutputStageInfo os_info = info.output_stage; |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 661 | |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 662 | arm_gemm::Requantize32 gemm_requant_info{}; |
| 663 | if(os_info.gemmlowp_shifts.size() > 1) |
| 664 | { |
| 665 | const auto requantize_data = fallback->set_requantize_data(os_info.gemmlowp_shifts, os_info.gemmlowp_multipliers); |
| 666 | gemm_requant_info = arm_gemm::Requantize32(nullptr, 0, |
| 667 | a_offset, b_offset, os_info.gemmlowp_offset, |
morgolock | 0bc80da | 2020-08-10 16:44:18 +0100 | [diff] [blame] | 668 | (std::get<0>(requantize_data)) ? std::get<1>(requantize_data) : nullptr, |
| 669 | std::get<2>(requantize_data), |
| 670 | std::get<3>(requantize_data), |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 671 | os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); |
| 672 | } |
| 673 | else |
| 674 | { |
| 675 | gemm_requant_info = arm_gemm::Requantize32(nullptr, 0, |
| 676 | a_offset, b_offset, os_info.gemmlowp_offset, |
| 677 | -os_info.gemmlowp_shift, os_info.gemmlowp_multiplier, |
| 678 | os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); |
| 679 | } |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 680 | |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 681 | // Configure fallback |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 682 | fallback->configure(a, b, c, d, args, info, gemm_requant_info); |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 683 | arm_gemm = std::move(fallback); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 684 | } |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 685 | } //namespace |
| 686 | |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 687 | CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch() |
| 688 | : _arm_gemm(nullptr) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 689 | { |
| 690 | } |
| 691 | |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 692 | Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 693 | const AsmGemmInfo &info) |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 694 | { |
| 695 | ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); |
| 696 | ARM_COMPUTE_UNUSED(c); |
| 697 | arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(info.activation_info); |
| 698 | Params p = extract_parameters(a, b, d, info); |
| 699 | const CPUInfo &ci = NEScheduler::get().cpu_info(); |
| 700 | unsigned int num_threads = NEScheduler::get().num_threads(); |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 701 | arm_gemm::GemmConfig cfg; |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 702 | cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); |
| 703 | arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format); |
| 704 | arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg); |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 705 | switch(a->data_type()) |
| 706 | { |
| 707 | case DataType::F32: |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 708 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 709 | "We could not find an optimized kernel for F32 input"); |
| 710 | break; |
| 711 | #ifdef __aarch64__ |
| 712 | case DataType::U8: |
| 713 | case DataType::QASYMM8: |
| 714 | if(d->data_type() == DataType::S32) |
| 715 | { |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 716 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), |
Ramy Elgammal | c8cc024 | 2022-10-05 17:05:20 +0100 | [diff] [blame] | 717 | "We could not find an optimized kernel for U8/QASYMM8 input and U32 output"); |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 718 | } |
| 719 | else |
| 720 | { |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 721 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})), |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 722 | "We could not find an optimized kernel for U8 input and U8 output"); |
| 723 | } |
| 724 | break; |
| 725 | case DataType::S8: |
| 726 | case DataType::QASYMM8_SIGNED: |
| 727 | if(d->data_type() == DataType::S32) |
| 728 | { |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 729 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 730 | "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output"); |
| 731 | } |
| 732 | else |
| 733 | { |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 734 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})), |
Ramy Elgammal | c8cc024 | 2022-10-05 17:05:20 +0100 | [diff] [blame] | 735 | "We could not find an optimized kernel for S8 input and S8 output"); |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 736 | } |
| 737 | break; |
| 738 | #endif /* __aarch64__ */ |
Pablo Marquez Tello | d208f4f | 2022-07-19 12:19:46 +0100 | [diff] [blame] | 739 | #if defined(ARM_COMPUTE_ENABLE_BF16) |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 740 | case DataType::BFLOAT16: |
| 741 | { |
Ramy Elgammal | aa52b7d | 2022-07-27 10:44:04 +0100 | [diff] [blame] | 742 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 743 | "We could not find an optimized kernel for BFLOAT16 input and F32 output"); |
| 744 | break; |
| 745 | } |
Pablo Marquez Tello | d208f4f | 2022-07-19 12:19:46 +0100 | [diff] [blame] | 746 | #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 747 | #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC |
| 748 | case DataType::F16: |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 749 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), |
Ramy Elgammal | c8cc024 | 2022-10-05 17:05:20 +0100 | [diff] [blame] | 750 | "We could not find an optimized kernel for F16 input and F16 output"); |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 751 | break; |
| 752 | #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ |
| 753 | default: |
| 754 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, "Usupported type. Could not find a kernel"); |
| 755 | break; |
| 756 | } |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 757 | expected_weight_format = assembly_utils::map_to_arm_compute_weight_format(arm_gemm_expected_wf); |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 758 | |
| 759 | return Status{}; |
| 760 | } |
| 761 | |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 762 | Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 763 | { |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 764 | ARM_COMPUTE_UNUSED(c, info); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 765 | ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d); |
| 766 | ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a); |
Georgios Pinitas | c7b183a | 2020-03-06 18:12:09 +0000 | [diff] [blame] | 767 | ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a); |
Mohammed Suhail Munshi | 4b5f6ef | 2022-10-21 11:15:54 +0100 | [diff] [blame] | 768 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(info.reshape_b_only_on_first_run), "Assembly kernel will not be executed when reshape_b_only_on_first_run is false"); |
Georgios Pinitas | 0f954eb | 2020-06-23 17:28:38 +0100 | [diff] [blame] | 769 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 770 | #ifndef __aarch64__ |
Michele Di Giorgio | 5255672 | 2019-12-23 16:35:12 +0000 | [diff] [blame] | 771 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->element_size() == 1, "8bit integer types only supported for aarch64"); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 772 | #endif /* __aarch64__ */ |
Georgios Pinitas | dbdea0d | 2019-10-16 19:21:40 +0100 | [diff] [blame] | 773 | ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S8, |
Georgios Pinitas | c7b183a | 2020-03-06 18:12:09 +0000 | [diff] [blame] | 774 | DataType::BFLOAT16, DataType::F16, DataType::F32); |
Georgios Pinitas | dbdea0d | 2019-10-16 19:21:40 +0100 | [diff] [blame] | 775 | ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(b, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::S8, |
Georgios Pinitas | c7b183a | 2020-03-06 18:12:09 +0000 | [diff] [blame] | 776 | DataType::BFLOAT16, DataType::F16, DataType::F32); |
Georgios Pinitas | dbdea0d | 2019-10-16 19:21:40 +0100 | [diff] [blame] | 777 | if(is_data_type_quantized_per_channel(b->data_type())) |
| 778 | { |
| 779 | ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::QASYMM8_SIGNED, DataType::S8); |
| 780 | } |
Jonathan Deakin | 464ed20 | 2023-01-12 11:41:14 +0000 | [diff] [blame] | 781 | else if(is_fixed_format_fast_math(info.weight_format)) |
| 782 | { |
| 783 | ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32); |
| 784 | ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16); |
| 785 | } |
Georgios Pinitas | dbdea0d | 2019-10-16 19:21:40 +0100 | [diff] [blame] | 786 | else |
| 787 | { |
| 788 | ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b); |
| 789 | } |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 790 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F32 && d->data_type() != DataType::F32, "Only F32 output supported for F32 input"); |
| 791 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16, "Only F16 output supported for F16 input"); |
Georgios Pinitas | c7b183a | 2020-03-06 18:12:09 +0000 | [diff] [blame] | 792 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 && d->data_type() != DataType::F32, "Only F32 output supported for BFLOAT16 input"); |
Anthony Barbier | 9036749 | 2018-08-01 13:56:08 +0100 | [diff] [blame] | 793 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input"); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 794 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input"); |
Ethan Doe | 1fe48ca | 2023-03-01 23:19:26 +0000 | [diff] [blame] | 795 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32), |
| 796 | "Only QASYMM8/S32 output supported for QASYMM8 input"); |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 797 | arm_compute::WeightFormat expected_weight_format; |
| 798 | const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info); |
| 799 | if((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY) |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 800 | { |
| 801 | // Correctness check: if the format expected by the kernel is |
| 802 | // not "any", make sure that the one found matches the format |
| 803 | // intended by the caller. |
| 804 | ARM_COMPUTE_RETURN_ERROR_ON_MSG((expected_weight_format != info.weight_format), |
| 805 | "The format expected by the kernel does not correspond with the one requested by the user."); |
| 806 | } |
| 807 | return ret; |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 808 | } |
| 809 | |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 810 | bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation) |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 811 | { |
Michele Di Giorgio | d02d5ed | 2021-01-22 09:47:04 +0000 | [diff] [blame] | 812 | arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(activation); |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 813 | return act.type != arm_gemm::Activation::Type::None; |
| 814 | } |
| 815 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 816 | void CpuGemmAssemblyDispatch::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, const AsmGemmInfo &info) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 817 | { |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 818 | ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); |
Michele Di Giorgio | d02d5ed | 2021-01-22 09:47:04 +0000 | [diff] [blame] | 819 | arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(info.activation_info); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 820 | |
| 821 | //If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured() |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 822 | if(!CpuGemmAssemblyDispatch::validate(a, b, c, d, info)) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 823 | { |
| 824 | return; |
| 825 | } |
| 826 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 827 | switch(a->data_type()) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 828 | { |
| 829 | case DataType::F32: |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 830 | create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 831 | break; |
| 832 | #ifdef __aarch64__ |
| 833 | case DataType::U8: |
| 834 | case DataType::QASYMM8: |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 835 | if(d->data_type() == DataType::S32) |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 836 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 837 | create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 838 | } |
| 839 | else |
| 840 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 841 | create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 842 | } |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 843 | break; |
| 844 | case DataType::S8: |
Georgios Pinitas | dbdea0d | 2019-10-16 19:21:40 +0100 | [diff] [blame] | 845 | case DataType::QASYMM8_SIGNED: |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 846 | if(d->data_type() == DataType::S32) |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 847 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 848 | create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 849 | } |
| 850 | else |
| 851 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 852 | create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 853 | } |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 854 | break; |
| 855 | #endif /* __aarch64__ */ |
Pablo Marquez Tello | d208f4f | 2022-07-19 12:19:46 +0100 | [diff] [blame] | 856 | #if defined(ARM_COMPUTE_ENABLE_BF16) |
Georgios Pinitas | c7b183a | 2020-03-06 18:12:09 +0000 | [diff] [blame] | 857 | case DataType::BFLOAT16: |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 858 | create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info); |
Georgios Pinitas | c7b183a | 2020-03-06 18:12:09 +0000 | [diff] [blame] | 859 | break; |
Pablo Marquez Tello | d208f4f | 2022-07-19 12:19:46 +0100 | [diff] [blame] | 860 | #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 861 | #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC |
| 862 | case DataType::F16: |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 863 | create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 864 | break; |
| 865 | #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ |
| 866 | default: |
| 867 | break; |
| 868 | } |
| 869 | } |
| 870 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 871 | void CpuGemmAssemblyDispatch::prepare(ITensorPack &tensors) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 872 | { |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 873 | ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 874 | _arm_gemm->prepare(tensors); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 875 | } |
| 876 | |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 877 | bool CpuGemmAssemblyDispatch::is_configured() const |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 878 | { |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 879 | return _arm_gemm && _arm_gemm->is_configured(); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 880 | } |
| 881 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 882 | void CpuGemmAssemblyDispatch::run(ITensorPack &tensors) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 883 | { |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 884 | ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 885 | _arm_gemm->run(tensors); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 886 | } |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 887 | |
| 888 | experimental::MemoryRequirements CpuGemmAssemblyDispatch::workspace() const |
| 889 | { |
| 890 | ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); |
| 891 | return _arm_gemm->workspace(); |
| 892 | } |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 893 | } // namespace cpu |
| 894 | } // namespace arm_compute |