Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 1 | /* |
Francesco.Petrogalli@arm.com | 193cad3 | 2022-03-07 13:39:21 +0000 | [diff] [blame] | 2 | * Copyright (c) 2018-2022 Arm Limited. |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 3 | * |
| 4 | * SPDX-License-Identifier: MIT |
| 5 | * |
| 6 | * Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | * of this software and associated documentation files (the "Software"), to |
| 8 | * deal in the Software without restriction, including without limitation the |
| 9 | * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or |
| 10 | * sell copies of the Software, and to permit persons to whom the Software is |
| 11 | * furnished to do so, subject to the following conditions: |
| 12 | * |
| 13 | * The above copyright notice and this permission notice shall be included in all |
| 14 | * copies or substantial portions of the Software. |
| 15 | * |
| 16 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 21 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 22 | * SOFTWARE. |
| 23 | */ |
Georgios Pinitas | 7891a73 | 2021-08-20 21:39:25 +0100 | [diff] [blame] | 24 | #include "src/cpu/operators/internal/CpuGemmAssemblyDispatch.h" |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 25 | |
Sang-Hoon Park | 68dd25f | 2020-10-19 16:00:11 +0100 | [diff] [blame] | 26 | #include "arm_compute/runtime/NEON/NEScheduler.h" |
| 27 | #include "src/core/CPP/Validate.h" |
Pablo Marquez Tello | 93581a5 | 2022-07-21 13:55:27 +0100 | [diff] [blame] | 28 | #include "src/core/NEON/kernels/arm_gemm/utils.hpp" |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 29 | #include "src/core/helpers/MemoryHelpers.h" |
Michele Di Giorgio | d02d5ed | 2021-01-22 09:47:04 +0000 | [diff] [blame] | 30 | #include "src/core/utils/AssemblyUtils.h" |
Georgios Pinitas | 7891a73 | 2021-08-20 21:39:25 +0100 | [diff] [blame] | 31 | #include "src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h" |
| 32 | #include "src/cpu/kernels/assembly/arm_gemm.hpp" |
| 33 | #include "src/cpu/utils/CpuAuxTensorHandler.h" |
Michele Di Giorgio | 6ad60af | 2020-06-09 14:52:15 +0100 | [diff] [blame] | 34 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 35 | #include <arm_neon.h> |
| 36 | |
Anthony Barbier | c8e84b5 | 2018-07-17 16:48:42 +0100 | [diff] [blame] | 37 | namespace arm_compute |
| 38 | { |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 39 | namespace cpu |
| 40 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 41 | using namespace arm_compute::experimental; |
| 42 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 43 | namespace |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 44 | { |
/** Deleter for std::unique_ptr that releases malloc-allocated storage with free(). */
struct free_delete
{
    void operator()(void *ptr)
    {
        free(ptr);
    }
};
| 52 | |
/** Shape and batching parameters describing one assembly-GEMM workload. */
struct Params
{
    unsigned int M;        // Rows of the output matrix (y of d, folded with z for GEMM3D)
    unsigned int N;        // Columns of the output matrix (x of d)
    unsigned int K;        // Accumulation depth (x of a)
    unsigned int batches;  // Number of independent batches
    unsigned int multis;   // Number of independent multi-GEMMs
    unsigned int sections; // For conv/indirect: kernel spatial extent (b shape[2] * shape[3]); 1 otherwise
    bool         indirect; // True when the GEMM runs through conv/indirect addressing
};
| 63 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 64 | Params extract_parameters(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 65 | { |
| 66 | ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 67 | Params p; |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 68 | p.M = d->tensor_shape().y(); |
| 69 | p.K = a->tensor_shape().x(); |
| 70 | p.N = d->tensor_shape().x(); |
Georgios Pinitas | 4c634e0 | 2020-12-01 02:17:19 +0000 | [diff] [blame] | 71 | p.batches = 1; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 72 | p.multis = 1; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 73 | p.sections = 1; |
Georgios Pinitas | 4c634e0 | 2020-12-01 02:17:19 +0000 | [diff] [blame] | 74 | p.indirect = false; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 75 | |
| 76 | if(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect) |
| 77 | { |
| 78 | p.indirect = true; |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 79 | p.sections = b->tensor_shape()[2] * b->tensor_shape()[3]; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 80 | } |
| 81 | else |
| 82 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 83 | p.multis = b->tensor_shape().z(); |
| 84 | p.batches = d->tensor_shape().total_size_upper(2) / p.multis; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 85 | } |
| 86 | |
| 87 | // Update M in case of GEMM3D for output |
| 88 | if(info.depth_output_gemm3d != 0) |
| 89 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 90 | p.M = d->tensor_shape().y() * d->tensor_shape().z(); |
| 91 | p.batches = d->tensor_shape().total_size_upper(3) / p.multis; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 92 | } |
| 93 | |
| 94 | return p; |
| 95 | } |
| 96 | |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 97 | IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataType data_type) |
| 98 | { |
| 99 | // Schedule assembly kernel |
| 100 | const int granule_threshold = 200; |
| 101 | IScheduler::Hints scheduling_hint = IScheduler::Hints(Window::DimX); |
| 102 | if(method == arm_gemm::GemmMethod::GEMM_INTERLEAVED && data_type == DataType::F32) |
| 103 | { |
| 104 | scheduling_hint = IScheduler::Hints(Window::DimX, IScheduler::StrategyHint::DYNAMIC, granule_threshold); |
| 105 | } |
| 106 | else if(method == arm_gemm::GemmMethod::GEMM_INTERLEAVED_2D && (data_type == DataType::F32 || data_type == DataType::F16 || data_type == DataType::U8 || data_type == DataType::S8)) |
| 107 | { |
| 108 | //GEMM_INTERLEAVED supports 2D parallelism, IScheduler::split_dimensions_all signals to parallelise over all window dimensions |
| 109 | scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold); |
| 110 | } |
| 111 | else if(method == arm_gemm::GemmMethod::QUANTIZE_WRAPPER_2D && (data_type == DataType::QASYMM8 || data_type == DataType::QASYMM8_SIGNED)) |
| 112 | { |
| 113 | //special case for QASYMM8 to support 2D parallelism, scheduler here may be tweaked differently compared to FP32 case |
| 114 | scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold); |
| 115 | } |
| 116 | |
| 117 | return scheduling_hint; |
| 118 | } |
| 119 | |
/** Fallback in case ACL doesn't have a function: wraps an arm_gemm kernel and
 *  adapts it to the IFallback run/prepare/workspace interface. */
template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
class Fallback : public CpuGemmAssemblyDispatch::IFallback
{
public:
    /** Destructor */
    ~Fallback() = default;

    /** Initialise the functions's input and output.
     *
     * @param[in]  a         Input tensor containing the Matrix A.
     * @param[in]  b         Input tensor containing the Matrix B.
     * @param[in]  c         Input tensor containing the Matrix C.
     * @param[out] d         Output tensor to store the result of matrix multiplication.
     * @param[in]  args      Matrix multiplication information.
     * @param[in]  gemm_info GEMM meta-data
     * @param[in]  os        Output stage meta-data.
     */
    void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
                   arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
                   const OutputStage &os = {});

    /** Set requantization data to be used
     *
     * @param[in] shifts      Requantization shifts
     * @param[in] multipliers Requantization multipliers
     *
     * @return A tuple with (need-left-shift flag, left-shift pointer, right-shift
     *         pointer, multiplier pointer), in that order
     */
    std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> set_requantize_data(const std::vector<int32_t> &shifts,
                                                                                            const std::vector<int32_t> &multipliers);

    // Inherited methods overridden:
    void run(ITensorPack &tensors) override;
    void prepare(ITensorPack &tensors) override;
    bool is_configured() const override;
    experimental::MemoryRequirements workspace() const override;
    /** True when the selected kernel reports a specific weight format
     *  (i.e. not UNSPECIFIED/ANY); false when unconfigured. */
    bool isVarWeightsKernel() const override
    {
        if(!_gemm_kernel_asm)
            return false;
        const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
        return wf != arm_compute::WeightFormat::UNSPECIFIED && wf != arm_compute::WeightFormat::ANY;
    }

private:
    /** Slot indices of the auxiliary tensors inside _aux_mem. */
    enum AuxTensorIdx
    {
        AsmGemmWorkspace = 0,
        Pretranspose,
        Count
    };

    /** Configure the indirect buffer
     *
     * @param[in]  a    Input tensor containing the Matrix A.
     * @param[in]  b    Input tensor containing the Matrix B.
     * @param[out] d    Output tensor to store the result of matrix multiplication.
     * @param[in]  info GEMM meta-data
     */
    void configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info);
    /** Prepare the indirect buffer */
    void prepare_indirect_buffer(ITensorPack &tensors);

    /** Assembly Gemm kernel */
    std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
    /** Optimised Arm® Neon™ kernel */
    std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
    /** Assembly GEMM workspace tensor info */
    TensorInfo _workspace_info{};
    /** Pre-transpose tensor info */
    TensorInfo _pretranspose_info{};
    /** Prepared flag */
    bool _is_prepared{ false };
    /** GEMM meta-data */
    AsmGemmInfo _gemm_info{};
    /** GEMM kernel description */
    arm_gemm::KernelDescription _kernel_info{};
    /** Per channel quantization shifts */
    std::vector<int32_t> _shifts{};
    /** Split negated-shift buffers filled by set_requantize_data():
     *  left = max(-s, 0), right = min(-s, 0) per channel. */
    std::vector<int32_t> right_shifts{};
    std::vector<int32_t> left_shifts{};
    /** Per channel quantization multipliers */
    std::vector<int32_t> _multipliers{};
    /** Indirect buffer (per-output-position input pointers and the argument
     *  table handed to arm_gemm); malloc-backed, freed via free_delete. */
    std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
    std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{};
    /** Padding row used for out-of-bounds input positions. */
    std::vector<TypeInput> _indirect_pad{};
    arm_gemm::ConvolutionParameters _cp{};
    experimental::MemoryRequirements _aux_mem{ Count };
    bool _B_pretranspose_required{ false };
    bool _is_b_constant{ true };
    bool _is_c_constant{ true };
};
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 221 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 222 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 223 | std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> |
| 224 | Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts, const std::vector<int32_t> &multipliers) |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 225 | { |
morgolock | 0bc80da | 2020-08-10 16:44:18 +0100 | [diff] [blame] | 226 | _multipliers = multipliers; |
| 227 | _shifts = shifts; |
| 228 | bool need_left = false; |
| 229 | for(const auto s : _shifts) |
| 230 | { |
| 231 | left_shifts.push_back(std::max(-s, int32_t(0))); |
| 232 | right_shifts.push_back(std::min(-s, int32_t(0))); |
morgolock | fa269bb | 2020-09-08 16:00:56 +0100 | [diff] [blame] | 233 | if(s < 0 && !need_left) |
morgolock | 0bc80da | 2020-08-10 16:44:18 +0100 | [diff] [blame] | 234 | { |
| 235 | need_left = true; |
| 236 | } |
| 237 | } |
| 238 | return std::make_tuple(need_left, left_shifts.data(), right_shifts.data(), _multipliers.data()); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 239 | } |
| 240 | |
/** Fill the indirect buffer with per-output-position input pointers.
 *
 * For every (multi, batch, output position, kernel position) the buffer entry
 * points at the corresponding input row of A, or at the shared zero-padding
 * row when the kernel tap falls outside the input plane.
 */
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors)
{
    auto             a     = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    const TypeInput *A_ptr = reinterpret_cast<TypeInput *>(a->buffer());
    const int        multis  = 1;
    const int        batches = a->info()->tensor_shape().total_size_upper(3);
    // A strides converted from bytes to elements.
    const size_t stride_A       = a->info()->strides_in_bytes().y() / sizeof(TypeInput);
    const size_t batch_stride_A = a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
    const size_t multi_stride_A = a->info()->strides_in_bytes()[4] / sizeof(TypeInput);

    // Strides within the indirect buffer, in pointer elements
    // (mirrors the sizing computed in configure_indirect()).
    const size_t output_hw    = _cp.output_height * _cp.output_width;
    const int    batch_size   = _cp.kernel_height * _cp.kernel_width * output_hw * sizeof(TypeInput);
    const size_t batch_stride = batch_size / sizeof(TypeInput);
    const int    multi_size   = batch_size * batches;
    const size_t multi_stride = multi_size / sizeof(TypeInput);

    for(int64_t m = 0; m < multis; m++)
    {
        for(int64_t b = 0; b < batches; b++)
        {
            for(int64_t output_y = 0; output_y < _cp.output_height; output_y++)
            {
                for(int64_t output_x = 0; output_x < _cp.output_width; output_x++)
                {
                    int64_t output_xy = (output_y * _cp.output_width) + output_x;

                    for(int64_t kernel_y = 0; kernel_y < _cp.kernel_height; kernel_y++)
                    {
                        for(int64_t kernel_x = 0; kernel_x < _cp.kernel_width; kernel_x++)
                        {
                            // Input coordinate sampled by this kernel tap, after
                            // applying stride and padding offsets.
                            int64_t input_x   = (output_x * _cp.output_stride_w) + kernel_x - _cp.padding_left;
                            int64_t input_y   = (output_y * _cp.output_stride_h) + kernel_y - _cp.padding_top;
                            int64_t kernel_xy = (kernel_y * _cp.kernel_width) + kernel_x;
                            int64_t input_xy  = (input_y * _cp.input_width) + input_x;

                            if(input_x < 0 || input_x >= _cp.input_width || input_y < 0 || input_y >= _cp.input_height)
                            {
                                // Out of bounds: point at the shared zero-padding row.
                                _indirect_buf.get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = _indirect_pad.data();
                            }
                            else
                            {
                                // In bounds: point directly at the input row in A.
                                _indirect_buf.get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
                                    A_ptr + (m * multi_stride_A + b * batch_stride_A + input_xy * stride_A);
                            }
                        }
                    }
                }
            }
        }
    }
}
| 293 | |
/** Set up convolution parameters and (for the Indirect method) allocate and
 *  wire up the indirect argument/pointer tables handed to arm_gemm.
 *
 * @param[in]  a    Input tensor info (Matrix A / NHWC input).
 * @param[in]  b    Input tensor info (Matrix B / weights).
 * @param[out] d    Output tensor info.
 * @param[in]  info GEMM meta-data (must request Conv or Indirect).
 */
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info)
{
    ARM_COMPUTE_ERROR_ON(!(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect));

    // Quantized inputs pad with the zero-point offset rather than 0.
    float zeropad = 0.f;
    if(is_data_type_quantized(a->data_type()))
    {
        zeropad = a->quantization_info().uniform().offset;
    }

    // Tensor shapes interpreted as [channels, width, height, ...] for a/d
    // and with kernel spatial dims at indices 2/3 for b.
    const int64_t input_width    = static_cast<int64_t>(a->tensor_shape()[1]);
    const int64_t input_height   = static_cast<int64_t>(a->tensor_shape()[2]);
    const int64_t input_channels = static_cast<int64_t>(a->tensor_shape()[0]);
    const int64_t kernel_width   = static_cast<int64_t>(b->tensor_shape()[2]);
    const int64_t kernel_height  = static_cast<int64_t>(b->tensor_shape()[3]);
    const int64_t output_width   = static_cast<int64_t>(d->tensor_shape()[1]);
    const int64_t output_height  = static_cast<int64_t>(d->tensor_shape()[2]);

    _cp = { input_width, input_height, input_channels, kernel_width, kernel_height, output_width, output_height,
            info.ps_info.stride().first, info.ps_info.stride().second, info.padding_top, info.padding_left, zeropad
          };

    if(info.method == AsmConvMethod::Conv)
    {
        // Conv method: arm_gemm does its own addressing from the parameters.
        _gemm_kernel_asm->set_convolution_parameters(_cp);
    }

    if(info.method == AsmConvMethod::Indirect)
    {
        const unsigned int multis    = 1;
        const unsigned int batches   = a->tensor_shape().total_size_upper(3);
        const unsigned int kernel_hw = _cp.kernel_width * _cp.kernel_height;
        const unsigned int output_hw = _cp.output_width * _cp.output_height;

        // Size the pointer table: one input pointer per (kernel tap, output position).
        using TypeInputPtr        = TypeInput *;
        const int    batch_size   = kernel_hw * output_hw * sizeof(TypeInputPtr);
        const size_t batch_stride = batch_size / sizeof(TypeInputPtr);
        const int    multi_size   = batch_size * batches;
        const size_t multi_stride = multi_size / sizeof(TypeInputPtr);

        // malloc-backed (paired with the free_delete deleter); contents are
        // filled later by prepare_indirect_buffer().
        _indirect_buf = std::unique_ptr<const TypeInput *, free_delete>(reinterpret_cast<const TypeInput **>(malloc(multi_size * multis)));
        _indirect_arg = std::unique_ptr<const TypeInput *const *, free_delete>(reinterpret_cast<const TypeInput *const **>(malloc(sizeof(TypeInput **) * kernel_hw * multis * batches)));
        _indirect_pad = std::vector<TypeInput>(_cp.input_channels, TypeInput(zeropad));

        // Set indirect argument: one entry per (multi, batch, kernel tap),
        // each pointing at its run of output_hw pointers in _indirect_buf.
        int64_t pos = 0;
        for(int64_t m = 0; m < multis; m++)
        {
            for(int64_t b = 0; b < batches; b++)
            {
                for(int64_t kernel_xy = 0; kernel_xy < kernel_hw; kernel_xy++)
                {
                    (_indirect_arg.get())[pos++] = _indirect_buf.get() + m * multi_stride + b * batch_stride + kernel_xy * output_hw;
                }
            }
        }

        _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.get());
    }
}
| 355 | |
| 356 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 357 | void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 358 | arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info, |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 359 | const OutputStage &os) |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 360 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 361 | ARM_COMPUTE_UNUSED(c); |
Giorgio Arena | 63e0beb | 2021-09-24 14:04:27 +0100 | [diff] [blame] | 362 | |
| 363 | _is_b_constant = b->are_values_constant(); |
| 364 | _is_c_constant = c ? c->are_values_constant() : true; |
| 365 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 366 | _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput, OutputStage>(args, os); |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 367 | if(_gemm_kernel_asm == nullptr) |
| 368 | { |
| 369 | //configuration not supported: Leave function unconfigured: |
| 370 | return; |
| 371 | } |
| 372 | |
Francesco.Petrogalli@arm.com | 193cad3 | 2022-03-07 13:39:21 +0000 | [diff] [blame] | 373 | arm_gemm::GemmConfig gemm_cfg = _gemm_kernel_asm->get_config(); |
| 374 | |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 375 | // arm_compute wrapper for the Gemm object (see above) |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 376 | auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>(); |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 377 | ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr); |
Georgios Pinitas | 3dbfd23 | 2019-01-30 17:17:16 +0000 | [diff] [blame] | 378 | acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter); |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 379 | const size_t workspace_size = _gemm_kernel_asm->get_working_size(); |
| 380 | const unsigned int alignment = 4096; |
| 381 | _workspace_info = TensorInfo(TensorShape(workspace_size), 1, DataType::U8); |
| 382 | _aux_mem[AsmGemmWorkspace] = MemoryInfo(offset_int_vec(AsmGemmWorkspace), MemoryLifetime::Temporary, workspace_size, alignment); |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 383 | |
| 384 | //if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and |
| 385 | //the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001 |
| 386 | { |
Georgios Pinitas | 5aa1a0b | 2020-07-02 20:02:20 +0100 | [diff] [blame] | 387 | const unsigned int window_size = _gemm_kernel_asm->get_window_size().total_size(); |
Joseph Dobson | 6f8b17d | 2020-02-11 19:32:11 +0000 | [diff] [blame] | 388 | if(window_size < static_cast<unsigned int>(args._maxthreads)) |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 389 | { |
Anthony Barbier | c8e84b5 | 2018-07-17 16:48:42 +0100 | [diff] [blame] | 390 | _gemm_kernel_asm->set_nthreads(window_size); |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 391 | } |
| 392 | } |
| 393 | |
| 394 | _optimised_kernel = std::move(acl_gemm_wrapper); |
Georgios Pinitas | 37d080f | 2019-06-21 18:43:12 +0100 | [diff] [blame] | 395 | _gemm_info = gemm_info; |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 396 | // Check for pre-transposed support |
| 397 | if(_gemm_kernel_asm->B_pretranspose_required()) |
| 398 | { |
| 399 | // Forcing 128-byte alignment (required by 32-bit kernels) |
| 400 | const unsigned int alignment = 128; |
| 401 | const size_t B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size(); |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 402 | _pretranspose_info = TensorInfo(TensorShape(B_pretranspose_size), 1, DataType::U8); |
| 403 | _aux_mem[Pretranspose] = MemoryInfo(offset_int_vec(Pretranspose), MemoryLifetime::Persistent, B_pretranspose_size, alignment); |
Giorgio Arena | 63e0beb | 2021-09-24 14:04:27 +0100 | [diff] [blame] | 404 | _B_pretranspose_required = true; |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 405 | } |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 406 | |
| 407 | // Handle indirect GEMM convolution |
| 408 | if(gemm_info.method == AsmConvMethod::Conv || gemm_info.method == AsmConvMethod::Indirect) |
| 409 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 410 | configure_indirect(a, b, d, gemm_info); |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 411 | } |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 412 | } |
| 413 | |
/** One-off preparation: hand the quantized bias to arm_gemm, pretranspose B
 *  into its persistent aux tensor if the kernel requires it, and fill the
 *  indirect buffer for the Indirect method. Subsequent calls are no-ops.
 */
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
{
    if(!_is_prepared)
    {
        auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
        auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);

        // Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C.
        if(c && c->info()->data_type() == DataType::S32)
        {
            _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
        }

        // Pretranspose B if required
        if(_gemm_kernel_asm->B_pretranspose_required())
        {
            // Fixed format kernels need no pretranspose.
            ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
            // B strides converted from bytes to elements.
            const int  ldb            = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
            const auto in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
            const int  multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);

            // Destination is the persistent Pretranspose aux tensor sized in configure().
            CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
            ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
            _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), in1_ptr, ldb, multi_stride_b);

            // Original B is no longer needed once the pretransposed copy exists.
            b->mark_as_unused();
        }

        if(_gemm_info.method == AsmConvMethod::Indirect)
        {
            prepare_indirect_buffer(tensors);
        }

        _is_prepared = true;
    }
}
| 452 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 453 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 454 | bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 455 | { |
| 456 | return _optimised_kernel != nullptr; |
| 457 | } |
| 458 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 459 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 460 | experimental::MemoryRequirements Fallback<TypeInput, TypeOutput, OutputStage>::workspace() const |
| 461 | { |
| 462 | return _aux_mem; |
| 463 | } |
| 464 | |
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
{
    // Unpack operands: A (LHS), B (RHS / weights), C (bias, optional) and D (destination).
    auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
    auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
    auto d = tensors.get_tensor(TensorType::ACL_DST);

    // Leading dimensions (row strides) in elements, as arm_gemm expects.
    // ldb stays 0 unless B must be dereferenced below (i.e. B is not pretransposed).
    int       lda = a->info()->strides_in_bytes().y() / sizeof(TypeInput);
    int       ldb = 0;
    const int ldd = d->info()->strides_in_bytes().y() / sizeof(TypeOutput);

    // When a tensor is reinterpreted as 3D, the batch dimension shifts up by one.
    const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d != 0 ? 3 : 2;
    const size_t a_multi_idx = a_batch_idx + 1;
    const size_t d_batch_idx = _gemm_info.depth_output_gemm3d != 0 ? 3 : 2;
    const size_t d_multi_idx = d_batch_idx + 1;

    // Batch/multi strides in elements; the A strides are non-const because the
    // indirect-convolution path zeroes them further down.
    int       batch_stride_a = a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput);
    const int batch_stride_d = d->info()->strides_in_bytes()[d_batch_idx] / sizeof(TypeOutput);

    int       multi_stride_a = a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput);
    int       multi_stride_b = 0;
    const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / sizeof(TypeOutput);

    auto             in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes());
    const TypeInput *in1_ptr = nullptr;
    auto             out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());

    // Check if B is pre-transposed and de-reference if not
    if(!_gemm_kernel_asm->B_is_pretransposed())
    {
        ldb                             = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
        multi_stride_b                  = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
        const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
        if(is_fixed_format(wf))
        {
            // The 4D tensor of dimension O'HWI' created for the
            // OHWIo<interleave_by>i<block_by> format is in reality seen
            // as a 2D tensor at arm_gemm level, where the rows are
            // O'/<interleave_by> and the columns are <interleave_by> *
            // H * W * I'.
            ITensorInfo      *tensor_info     = b->info();
            const DataLayout  data_layout     = tensor_info->data_layout();
            const TensorShape tensor_shape    = tensor_info->tensor_shape();
            const int         tensor_height   = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
            const int         tensor_width    = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
            int               tensor_channels = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)];
            const int         interleave_by   = arm_compute::interleave_by(wf);
            const int         blocked_by      = arm_compute::block_by(wf);
            // We need to find a new stride that is distance from the data for one
            // set of output channels to the next
            if(ldb == tensor_channels && multi_stride_b == tensor_channels * tensor_width)
            {
                // In this case dimensions that are packed are height, width and channel
                // so we need to stride it by interleave_by
                if(tensor_channels % blocked_by != 0)
                {
                    // We need to pad: round channels up to the next multiple of block_by.
                    tensor_channels = arm_gemm::iceildiv(tensor_channels, blocked_by) * blocked_by;
                }
                ldb = interleave_by * tensor_height * tensor_width * tensor_channels;
            }
            else if(multi_stride_b == 0 || (ldb == tensor_width && multi_stride_b == tensor_height * tensor_width))
            {
                // In this case dimension that is packed is only height
                // so we need to stride only height by interleave_by
                ldb = interleave_by * tensor_height;
            }
            else
            {
                // If dimensions are not packed as above error is thrown
                // as at the moment other forms of packing are not supported
                ARM_COMPUTE_ERROR("Unsupported packing for fixed format kernel");
            }
        }
        in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
    }

    // If necessary, run pretranspose every time if either weights or biases are non-constant
    if((b && !_is_b_constant) || (c && !_is_c_constant && c->info()->data_type() == DataType::S32))
    {
        if(c && c->info()->data_type() == DataType::S32)
        {
            // S32 bias is folded into the assembly kernel's requantization stage.
            _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
        }

        // Pretranspose B if required
        if(_B_pretranspose_required)
        {
            const int  ldb            = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
            const auto b_ptr          = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
            const int  multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);

            CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
            ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);

            if(_is_b_constant)
            {
                // B itself is constant: only the bias needs requantizing into the buffer.
                _gemm_kernel_asm->requantize_bias(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
            }
            else
            {
                _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
            }
        }
    }

    const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, d->info()->data_type());

    // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
    CpuAuxTensorHandler workspace(offset_int_vec(AsmGemmWorkspace), _workspace_info, tensors, false);
    if(workspace.get()->buffer() != nullptr)
    {
        _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(workspace.get()->buffer()));
        const unsigned int split_dim   = scheduling_hint.split_dimension();
        const unsigned int window_size = _gemm_kernel_asm->get_window_size().total_size();
        unsigned int       num_threads = NEScheduler::get().num_threads();
        if(window_size < num_threads)
        {
            // No point spawning more threads than there are units of work.
            num_threads = window_size;
        }
        if(split_dim != IScheduler::split_dimensions_all)
        {
            // Make sure the kernel does not expect more threads than we can actually spawn
            const unsigned int num_iterations = _optimised_kernel.get()->window().num_iterations(split_dim);
            num_threads                       = std::min(num_iterations, num_threads);
        }
        _gemm_kernel_asm->set_nthreads(num_threads);
    }

    // Prepare assembly kernel
    prepare(tensors);

    // Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C.
    // (S32 bias was already handed over via set_quantized_bias above.)
    TypeOutput *bias = nullptr;
    if(c && c->info()->data_type() != DataType::S32)
    {
        bias = reinterpret_cast<TypeOutput *>(c->buffer() + c->info()->offset_first_element_in_bytes());
    }

    if(_gemm_info.method == AsmConvMethod::Indirect)
    {
        // Indirect convolution: A is addressed through the indirect buffer set up in
        // prepare(), so the direct A pointer/strides must be cleared.
        in0_ptr        = nullptr;
        lda            = 0;
        batch_stride_a = 0;
        multi_stride_a = 0;
    }

    // Set gemm parameters
    _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a,
                                 in1_ptr, ldb, multi_stride_b,
                                 out_ptr, ldd, batch_stride_d, multi_stride_d,
                                 bias, 0);
    // Schedule
    NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint);
}
| 621 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 622 | template <typename TypeInput, typename TypeOutput> |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 623 | void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, |
| 624 | const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, |
| 625 | arm_gemm::Activation activation, const AsmGemmInfo &info) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 626 | { |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 627 | Params p = extract_parameters(a, b, d, info); |
| 628 | const CPUInfo &ci = NEScheduler::get().cpu_info(); |
| 629 | unsigned int num_threads = NEScheduler::get().num_threads(); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 630 | |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 631 | arm_gemm::GemmConfig cfg; |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 632 | cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 633 | arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 634 | |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 635 | // Create arm_gemm fallback |
Georgios Pinitas | 40f51a6 | 2020-11-21 03:04:18 +0000 | [diff] [blame] | 636 | auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>(); |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 637 | fallback->configure(a, b, c, d, args, info); |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 638 | arm_gemm = std::move(fallback); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 639 | } |
| 640 | |
| 641 | template <typename TypeInput, typename TypeOutput> |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 642 | void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, |
| 643 | const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, |
| 644 | arm_gemm::Activation activation, const AsmGemmInfo &info) |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 645 | { |
Michele Di Giorgio | 6ad60af | 2020-06-09 14:52:15 +0100 | [diff] [blame] | 646 | ARM_COMPUTE_UNUSED(activation); |
Georgios Pinitas | 4ee8b15 | 2021-07-16 16:16:43 +0100 | [diff] [blame] | 647 | Params p = extract_parameters(a, b, d, info); |
| 648 | const CPUInfo &ci = NEScheduler::get().cpu_info(); |
| 649 | const unsigned int num_threads = NEScheduler::get().num_threads(); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 650 | |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 651 | arm_gemm::GemmConfig cfg; |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 652 | cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 653 | arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 654 | |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 655 | // Create arm_gemm fallback |
Georgios Pinitas | 40f51a6 | 2020-11-21 03:04:18 +0000 | [diff] [blame] | 656 | auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>(); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 657 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 658 | // Configure requantization info |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 659 | const int32_t negation = info.negated_offsets ? 1 : -1; |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 660 | const int32_t a_offset = -a->quantization_info().uniform().offset * negation; |
| 661 | const int32_t b_offset = -b->quantization_info().uniform().offset * negation; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 662 | const GEMMLowpOutputStageInfo os_info = info.output_stage; |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 663 | |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 664 | arm_gemm::Requantize32 gemm_requant_info{}; |
| 665 | if(os_info.gemmlowp_shifts.size() > 1) |
| 666 | { |
| 667 | const auto requantize_data = fallback->set_requantize_data(os_info.gemmlowp_shifts, os_info.gemmlowp_multipliers); |
| 668 | gemm_requant_info = arm_gemm::Requantize32(nullptr, 0, |
| 669 | a_offset, b_offset, os_info.gemmlowp_offset, |
morgolock | 0bc80da | 2020-08-10 16:44:18 +0100 | [diff] [blame] | 670 | (std::get<0>(requantize_data)) ? std::get<1>(requantize_data) : nullptr, |
| 671 | std::get<2>(requantize_data), |
| 672 | std::get<3>(requantize_data), |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 673 | os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); |
| 674 | } |
| 675 | else |
| 676 | { |
| 677 | gemm_requant_info = arm_gemm::Requantize32(nullptr, 0, |
| 678 | a_offset, b_offset, os_info.gemmlowp_offset, |
| 679 | -os_info.gemmlowp_shift, os_info.gemmlowp_multiplier, |
| 680 | os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); |
| 681 | } |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 682 | |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 683 | // Configure fallback |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 684 | fallback->configure(a, b, c, d, args, info, gemm_requant_info); |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 685 | arm_gemm = std::move(fallback); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 686 | } |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 687 | } //namespace |
| 688 | |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 689 | CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch() |
| 690 | : _arm_gemm(nullptr) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 691 | { |
| 692 | } |
| 693 | |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 694 | Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 695 | const AsmGemmInfo &info) |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 696 | { |
| 697 | ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); |
| 698 | ARM_COMPUTE_UNUSED(c); |
| 699 | arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(info.activation_info); |
| 700 | Params p = extract_parameters(a, b, d, info); |
| 701 | const CPUInfo &ci = NEScheduler::get().cpu_info(); |
| 702 | unsigned int num_threads = NEScheduler::get().num_threads(); |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 703 | arm_gemm::GemmConfig cfg; |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 704 | cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); |
| 705 | arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format); |
| 706 | arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg); |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 707 | switch(a->data_type()) |
| 708 | { |
| 709 | case DataType::F32: |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 710 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 711 | "We could not find an optimized kernel for F32 input"); |
| 712 | break; |
| 713 | #ifdef __aarch64__ |
| 714 | case DataType::U8: |
| 715 | case DataType::QASYMM8: |
| 716 | if(d->data_type() == DataType::S32) |
| 717 | { |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 718 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), |
Ramy Elgammal | c8cc024 | 2022-10-05 17:05:20 +0100 | [diff] [blame] | 719 | "We could not find an optimized kernel for U8/QASYMM8 input and U32 output"); |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 720 | } |
| 721 | else |
| 722 | { |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 723 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})), |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 724 | "We could not find an optimized kernel for U8 input and U8 output"); |
| 725 | } |
| 726 | break; |
| 727 | case DataType::S8: |
| 728 | case DataType::QASYMM8_SIGNED: |
| 729 | if(d->data_type() == DataType::S32) |
| 730 | { |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 731 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 732 | "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output"); |
| 733 | } |
| 734 | else |
| 735 | { |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 736 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})), |
Ramy Elgammal | c8cc024 | 2022-10-05 17:05:20 +0100 | [diff] [blame] | 737 | "We could not find an optimized kernel for S8 input and S8 output"); |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 738 | } |
| 739 | break; |
| 740 | #endif /* __aarch64__ */ |
Pablo Marquez Tello | d208f4f | 2022-07-19 12:19:46 +0100 | [diff] [blame] | 741 | #if defined(ARM_COMPUTE_ENABLE_BF16) |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 742 | case DataType::BFLOAT16: |
| 743 | { |
Ramy Elgammal | aa52b7d | 2022-07-27 10:44:04 +0100 | [diff] [blame] | 744 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 745 | "We could not find an optimized kernel for BFLOAT16 input and F32 output"); |
| 746 | break; |
| 747 | } |
Pablo Marquez Tello | d208f4f | 2022-07-19 12:19:46 +0100 | [diff] [blame] | 748 | #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 749 | #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC |
| 750 | case DataType::F16: |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 751 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), |
Ramy Elgammal | c8cc024 | 2022-10-05 17:05:20 +0100 | [diff] [blame] | 752 | "We could not find an optimized kernel for F16 input and F16 output"); |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 753 | break; |
| 754 | #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ |
| 755 | default: |
| 756 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, "Usupported type. Could not find a kernel"); |
| 757 | break; |
| 758 | } |
Ramy Elgammal | 9178002 | 2022-07-20 14:57:37 +0100 | [diff] [blame] | 759 | expected_weight_format = assembly_utils::map_to_arm_compute_weight_format(arm_gemm_expected_wf); |
Francesco.Petrogalli@arm.com | e33c556 | 2022-03-31 17:55:35 +0000 | [diff] [blame] | 760 | |
| 761 | return Status{}; |
| 762 | } |
| 763 | |
// Static validation of the operand/info combination for the assembly dispatch.
// Checks run in order: null/arch/feature guards, data-type combinations, then a
// kernel-availability query via has_opt_impl(). Returns the first failing check.
Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info)
{
    ARM_COMPUTE_UNUSED(c, info);
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
    // Reject F16/BF16 inputs on CPUs without the corresponding hardware support.
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(info.reshape_b_only_on_first_run), "Assembly kernel will not be executed when reshape_b_only_on_first_run is false");

#ifndef __aarch64__
    // 8-bit kernels are only implemented for aarch64.
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->element_size() == 1, "8bit integer types only supported for aarch64");
#endif /* __aarch64__ */
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S8,
                                                         DataType::BFLOAT16, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(b, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::S8,
                                                         DataType::BFLOAT16, DataType::F16, DataType::F32);
    if(is_data_type_quantized_per_channel(b->data_type()))
    {
        // Per-channel quantized weights are only paired with signed 8-bit inputs.
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::QASYMM8_SIGNED, DataType::S8);
    }
    else
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
    }
    // Input/output data-type pairings supported by the assembly kernels.
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F32 && d->data_type() != DataType::F32, "Only F32 output supported for F32 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16, "Only F16 output supported for F16 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 && d->data_type() != DataType::F32, "Only F32 output supported for BFLOAT16 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input");
    // Finally ask arm_gemm whether an optimized kernel actually exists.
    arm_compute::WeightFormat expected_weight_format;
    const Status              ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
    if((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY)
    {
        // Correctness check: if the format expected by the kernel is
        // not "any", make sure that the one found matches the format
        // intended by the caller.
        ARM_COMPUTE_RETURN_ERROR_ON_MSG((expected_weight_format != info.weight_format),
                                        "The format expected by the kernel does not correspond with the one requested by the user.");
    }
    return ret;
}
| 805 | |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 806 | bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation) |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 807 | { |
Michele Di Giorgio | d02d5ed | 2021-01-22 09:47:04 +0000 | [diff] [blame] | 808 | arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(activation); |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 809 | return act.type != arm_gemm::Activation::Type::None; |
| 810 | } |
| 811 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 812 | void CpuGemmAssemblyDispatch::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, const AsmGemmInfo &info) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 813 | { |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 814 | ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); |
Michele Di Giorgio | d02d5ed | 2021-01-22 09:47:04 +0000 | [diff] [blame] | 815 | arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(info.activation_info); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 816 | |
| 817 | //If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured() |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 818 | if(!CpuGemmAssemblyDispatch::validate(a, b, c, d, info)) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 819 | { |
| 820 | return; |
| 821 | } |
| 822 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 823 | switch(a->data_type()) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 824 | { |
| 825 | case DataType::F32: |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 826 | create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 827 | break; |
| 828 | #ifdef __aarch64__ |
| 829 | case DataType::U8: |
| 830 | case DataType::QASYMM8: |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 831 | if(d->data_type() == DataType::S32) |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 832 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 833 | create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 834 | } |
| 835 | else |
| 836 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 837 | create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 838 | } |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 839 | break; |
| 840 | case DataType::S8: |
Georgios Pinitas | dbdea0d | 2019-10-16 19:21:40 +0100 | [diff] [blame] | 841 | case DataType::QASYMM8_SIGNED: |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 842 | if(d->data_type() == DataType::S32) |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 843 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 844 | create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 845 | } |
| 846 | else |
| 847 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 848 | create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 849 | } |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 850 | break; |
| 851 | #endif /* __aarch64__ */ |
Pablo Marquez Tello | d208f4f | 2022-07-19 12:19:46 +0100 | [diff] [blame] | 852 | #if defined(ARM_COMPUTE_ENABLE_BF16) |
Georgios Pinitas | c7b183a | 2020-03-06 18:12:09 +0000 | [diff] [blame] | 853 | case DataType::BFLOAT16: |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 854 | create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info); |
Georgios Pinitas | c7b183a | 2020-03-06 18:12:09 +0000 | [diff] [blame] | 855 | break; |
Pablo Marquez Tello | d208f4f | 2022-07-19 12:19:46 +0100 | [diff] [blame] | 856 | #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 857 | #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC |
| 858 | case DataType::F16: |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 859 | create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 860 | break; |
| 861 | #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ |
| 862 | default: |
| 863 | break; |
| 864 | } |
| 865 | } |
| 866 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 867 | void CpuGemmAssemblyDispatch::prepare(ITensorPack &tensors) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 868 | { |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 869 | ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 870 | _arm_gemm->prepare(tensors); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 871 | } |
| 872 | |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 873 | bool CpuGemmAssemblyDispatch::is_configured() const |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 874 | { |
Francesco Petrogalli | 553f695 | 2022-06-30 10:22:01 +0000 | [diff] [blame] | 875 | return _arm_gemm && _arm_gemm->is_configured(); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 876 | } |
| 877 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 878 | void CpuGemmAssemblyDispatch::run(ITensorPack &tensors) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 879 | { |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 880 | ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 881 | _arm_gemm->run(tensors); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 882 | } |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 883 | |
| 884 | experimental::MemoryRequirements CpuGemmAssemblyDispatch::workspace() const |
| 885 | { |
| 886 | ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); |
| 887 | return _arm_gemm->workspace(); |
| 888 | } |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 889 | } // namespace cpu |
| 890 | } // namespace arm_compute |