Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 1 | /* |
Francesco.Petrogalli@arm.com | 193cad3 | 2022-03-07 13:39:21 +0000 | [diff] [blame] | 2 | * Copyright (c) 2018-2022 Arm Limited. |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 3 | * |
| 4 | * SPDX-License-Identifier: MIT |
| 5 | * |
| 6 | * Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | * of this software and associated documentation files (the "Software"), to |
| 8 | * deal in the Software without restriction, including without limitation the |
| 9 | * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or |
| 10 | * sell copies of the Software, and to permit persons to whom the Software is |
| 11 | * furnished to do so, subject to the following conditions: |
| 12 | * |
| 13 | * The above copyright notice and this permission notice shall be included in all |
| 14 | * copies or substantial portions of the Software. |
| 15 | * |
| 16 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 21 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 22 | * SOFTWARE. |
| 23 | */ |
Georgios Pinitas | 7891a73 | 2021-08-20 21:39:25 +0100 | [diff] [blame] | 24 | #include "src/cpu/operators/internal/CpuGemmAssemblyDispatch.h" |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 25 | |
Sang-Hoon Park | 68dd25f | 2020-10-19 16:00:11 +0100 | [diff] [blame] | 26 | #include "arm_compute/runtime/NEON/NEScheduler.h" |
| 27 | #include "src/core/CPP/Validate.h" |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 28 | #include "src/core/helpers/MemoryHelpers.h" |
Michele Di Giorgio | d02d5ed | 2021-01-22 09:47:04 +0000 | [diff] [blame] | 29 | #include "src/core/utils/AssemblyUtils.h" |
Georgios Pinitas | 7891a73 | 2021-08-20 21:39:25 +0100 | [diff] [blame] | 30 | #include "src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h" |
| 31 | #include "src/cpu/kernels/assembly/arm_gemm.hpp" |
| 32 | #include "src/cpu/utils/CpuAuxTensorHandler.h" |
Michele Di Giorgio | 6ad60af | 2020-06-09 14:52:15 +0100 | [diff] [blame] | 33 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 34 | #include <arm_neon.h> |
| 35 | |
Anthony Barbier | c8e84b5 | 2018-07-17 16:48:42 +0100 | [diff] [blame] | 36 | namespace arm_compute |
| 37 | { |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 38 | namespace cpu |
| 39 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 40 | using namespace arm_compute::experimental; |
| 41 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 42 | namespace |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 43 | { |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 44 | struct free_delete |
| 45 | { |
| 46 | void operator()(void *x) |
| 47 | { |
| 48 | free(x); |
| 49 | } |
| 50 | }; |
| 51 | |
/** Shape parameters of a GEMM workload, as consumed by the assembly backend.
 *
 * Default values match the neutral configuration assigned in extract_parameters():
 * a single batch/multi/section, non-indirect. Members are brace-initialized so a
 * default-constructed Params never carries indeterminate values.
 */
struct Params
{
    unsigned int M{ 0 };            /**< Rows of the output matrix. */
    unsigned int N{ 0 };            /**< Columns of the output matrix. */
    unsigned int K{ 0 };            /**< Inner (accumulation) dimension. */
    unsigned int batches{ 1 };      /**< Number of batched GEMMs. */
    unsigned int multis{ 1 };       /**< Number of independent multi-GEMMs. */
    unsigned int sections{ 1 };     /**< Number of sections (kernel positions for conv/indirect). */
    bool         indirect{ false }; /**< True when an indirect/conv buffer is used. */
};
| 62 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 63 | Params extract_parameters(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info) |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 64 | { |
| 65 | ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 66 | Params p; |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 67 | p.M = d->tensor_shape().y(); |
| 68 | p.K = a->tensor_shape().x(); |
| 69 | p.N = d->tensor_shape().x(); |
Georgios Pinitas | 4c634e0 | 2020-12-01 02:17:19 +0000 | [diff] [blame] | 70 | p.batches = 1; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 71 | p.multis = 1; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 72 | p.sections = 1; |
Georgios Pinitas | 4c634e0 | 2020-12-01 02:17:19 +0000 | [diff] [blame] | 73 | p.indirect = false; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 74 | |
| 75 | if(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect) |
| 76 | { |
| 77 | p.indirect = true; |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 78 | p.sections = b->tensor_shape()[2] * b->tensor_shape()[3]; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 79 | } |
| 80 | else |
| 81 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 82 | p.multis = b->tensor_shape().z(); |
| 83 | p.batches = d->tensor_shape().total_size_upper(2) / p.multis; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 84 | } |
| 85 | |
| 86 | // Update M in case of GEMM3D for output |
| 87 | if(info.depth_output_gemm3d != 0) |
| 88 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 89 | p.M = d->tensor_shape().y() * d->tensor_shape().z(); |
| 90 | p.batches = d->tensor_shape().total_size_upper(3) / p.multis; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 91 | } |
| 92 | |
| 93 | return p; |
| 94 | } |
| 95 | |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 96 | IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataType data_type) |
| 97 | { |
| 98 | // Schedule assembly kernel |
| 99 | const int granule_threshold = 200; |
| 100 | IScheduler::Hints scheduling_hint = IScheduler::Hints(Window::DimX); |
| 101 | if(method == arm_gemm::GemmMethod::GEMM_INTERLEAVED && data_type == DataType::F32) |
| 102 | { |
| 103 | scheduling_hint = IScheduler::Hints(Window::DimX, IScheduler::StrategyHint::DYNAMIC, granule_threshold); |
| 104 | } |
| 105 | else if(method == arm_gemm::GemmMethod::GEMM_INTERLEAVED_2D && (data_type == DataType::F32 || data_type == DataType::F16 || data_type == DataType::U8 || data_type == DataType::S8)) |
| 106 | { |
| 107 | //GEMM_INTERLEAVED supports 2D parallelism, IScheduler::split_dimensions_all signals to parallelise over all window dimensions |
| 108 | scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold); |
| 109 | } |
| 110 | else if(method == arm_gemm::GemmMethod::QUANTIZE_WRAPPER_2D && (data_type == DataType::QASYMM8 || data_type == DataType::QASYMM8_SIGNED)) |
| 111 | { |
| 112 | //special case for QASYMM8 to support 2D parallelism, scheduler here may be tweaked differently compared to FP32 case |
| 113 | scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold); |
| 114 | } |
| 115 | |
| 116 | return scheduling_hint; |
| 117 | } |
| 118 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 119 | /** Fallback in case ACL doesn't have a function */ |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 120 | template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing> |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 121 | class Fallback : public CpuGemmAssemblyDispatch::IFallback |
Anthony Barbier | c8e84b5 | 2018-07-17 16:48:42 +0100 | [diff] [blame] | 122 | { |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 123 | public: |
Michalis Spyrou | 1a569a3 | 2019-09-10 17:20:34 +0100 | [diff] [blame] | 124 | /** Destructor */ |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 125 | ~Fallback() = default; |
Michalis Spyrou | 1a569a3 | 2019-09-10 17:20:34 +0100 | [diff] [blame] | 126 | |
Georgios Pinitas | 3dbfd23 | 2019-01-30 17:17:16 +0000 | [diff] [blame] | 127 | /** Initialise the functions's input and output. |
| 128 | * |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 129 | * @param[in] a Input tensor containing the Matrix A. |
| 130 | * @param[in] b Input tensor containing the Matrix B. |
| 131 | * @param[in] c Input tensor containing the Matrix C. |
| 132 | * @param[out] d Output tensor to store the result of matrix multiplication. |
| 133 | * @param[in] args Matrix multiplication information. |
| 134 | * @param[in] gemm_info GEMM meta-data |
| 135 | * @param[in] os Output stage meta-data. |
Georgios Pinitas | 3dbfd23 | 2019-01-30 17:17:16 +0000 | [diff] [blame] | 136 | */ |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 137 | void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 138 | arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info, |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 139 | const OutputStage &os = {}); |
Georgios Pinitas | 3dbfd23 | 2019-01-30 17:17:16 +0000 | [diff] [blame] | 140 | |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 141 | /** Set requantization shifts to be used |
| 142 | * |
| 143 | * @param[in] shifts Requantization shifts |
| 144 | * |
| 145 | * @return Pointer to the shift data |
| 146 | */ |
| 147 | /** Set requantization data to be used |
| 148 | * |
| 149 | * |
| 150 | * @param shifts Requantization shifts |
| 151 | * @param multipliers Requantization multipliers |
| 152 | * |
| 153 | * @return A tuple with the pointers to the shift and multiplier data respectively |
| 154 | */ |
morgolock | 0bc80da | 2020-08-10 16:44:18 +0100 | [diff] [blame] | 155 | std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> set_requantize_data(const std::vector<int32_t> &shifts, |
| 156 | const std::vector<int32_t> &multipliers); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 157 | |
Georgios Pinitas | 3dbfd23 | 2019-01-30 17:17:16 +0000 | [diff] [blame] | 158 | // Inherited methods overridden: |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 159 | void run(ITensorPack &tensors) override; |
| 160 | void prepare(ITensorPack &tensors) override; |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 161 | bool is_configured() const override; |
| 162 | experimental::MemoryRequirements workspace() const override; |
Anthony Barbier | c8e84b5 | 2018-07-17 16:48:42 +0100 | [diff] [blame] | 163 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 164 | private: |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 165 | enum AuxTensorIdx |
| 166 | { |
| 167 | AsmGemmWorkspace = 0, |
| 168 | Pretranspose, |
| 169 | Count |
| 170 | }; |
| 171 | |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 172 | /** Configure the indirect buffer |
| 173 | * |
| 174 | * @param[in] a Input tensor containing the Matrix A. |
| 175 | * @param[in] b Input tensor containing the Matrix B. |
| 176 | * @param[out] d Output tensor to store the result of matrix multiplication. |
| 177 | * @param[in] info GEMM meta-data |
| 178 | */ |
| 179 | void configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info); |
| 180 | /** Prepare the indirect buffer */ |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 181 | void prepare_indirect_buffer(ITensorPack &tensors); |
Anthony Barbier | c8e84b5 | 2018-07-17 16:48:42 +0100 | [diff] [blame] | 182 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 183 | /** Assembly Gemm kernel */ |
Michalis Spyrou | 1a569a3 | 2019-09-10 17:20:34 +0100 | [diff] [blame] | 184 | std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr }; |
Michele Di Giorgio | 33f41fa | 2021-03-09 14:09:08 +0000 | [diff] [blame] | 185 | /** Optimised Arm® Neon™ kernel */ |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 186 | std::unique_ptr<INEKernel> _optimised_kernel{ nullptr }; |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 187 | /** Assembly GEMM workspace tensor info */ |
| 188 | TensorInfo _workspace_info{}; |
| 189 | /** Pre-transpose tensor info */ |
| 190 | TensorInfo _pretranspose_info{}; |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 191 | /** Prepared flag */ |
| 192 | bool _is_prepared{ false }; |
Georgios Pinitas | 37d080f | 2019-06-21 18:43:12 +0100 | [diff] [blame] | 193 | /** GEMM meta-data */ |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 194 | AsmGemmInfo _gemm_info{}; |
Georgios Pinitas | 77d4252 | 2019-11-05 13:35:47 +0000 | [diff] [blame] | 195 | /** GEMM kernel description */ |
| 196 | arm_gemm::KernelDescription _kernel_info{}; |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 197 | /** Per channel quantization shifts */ |
| 198 | std::vector<int32_t> _shifts{}; |
morgolock | 0bc80da | 2020-08-10 16:44:18 +0100 | [diff] [blame] | 199 | std::vector<int32_t> right_shifts{}; |
| 200 | std::vector<int32_t> left_shifts{}; |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 201 | /** Per channel quantization multipliers */ |
| 202 | std::vector<int32_t> _multipliers{}; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 203 | /** Indirect buffer */ |
| 204 | std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{}; |
| 205 | std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{}; |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 206 | std::vector<TypeInput> _indirect_pad{}; |
| 207 | arm_gemm::ConvolutionParameters _cp{}; |
| 208 | experimental::MemoryRequirements _aux_mem{ Count }; |
Giorgio Arena | 63e0beb | 2021-09-24 14:04:27 +0100 | [diff] [blame] | 209 | bool _B_pretranspose_required{ false }; |
| 210 | bool _is_b_constant{ true }; |
| 211 | bool _is_c_constant{ true }; |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 212 | }; |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 213 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 214 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 215 | std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> |
| 216 | Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts, const std::vector<int32_t> &multipliers) |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 217 | { |
morgolock | 0bc80da | 2020-08-10 16:44:18 +0100 | [diff] [blame] | 218 | _multipliers = multipliers; |
| 219 | _shifts = shifts; |
| 220 | bool need_left = false; |
| 221 | for(const auto s : _shifts) |
| 222 | { |
| 223 | left_shifts.push_back(std::max(-s, int32_t(0))); |
| 224 | right_shifts.push_back(std::min(-s, int32_t(0))); |
morgolock | fa269bb | 2020-09-08 16:00:56 +0100 | [diff] [blame] | 225 | if(s < 0 && !need_left) |
morgolock | 0bc80da | 2020-08-10 16:44:18 +0100 | [diff] [blame] | 226 | { |
| 227 | need_left = true; |
| 228 | } |
| 229 | } |
| 230 | return std::make_tuple(need_left, left_shifts.data(), right_shifts.data(), _multipliers.data()); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 231 | } |
| 232 | |
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors)
{
    // Populate the indirect buffer: for every (multi, batch, kernel point, output point) store a
    // pointer to the matching input element, or to the zero-padding vector when the kernel point
    // falls outside the input tensor.
    auto             a              = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    const TypeInput *A_ptr          = reinterpret_cast<TypeInput *>(a->buffer());
    const int        multis         = 1;
    const int        batches        = a->info()->tensor_shape().total_size_upper(3);
    // Strides of A converted from bytes to elements.
    const size_t     stride_A       = a->info()->strides_in_bytes().y() / sizeof(TypeInput);
    const size_t     batch_stride_A = a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
    const size_t     multi_stride_A = a->info()->strides_in_bytes()[4] / sizeof(TypeInput);

    // Strides within the indirect buffer itself, in elements.
    const size_t output_hw    = _cp.output_height * _cp.output_width;
    const int    batch_size   = _cp.kernel_height * _cp.kernel_width * output_hw * sizeof(TypeInput);
    const size_t batch_stride = batch_size / sizeof(TypeInput);
    const int    multi_size   = batch_size * batches;
    const size_t multi_stride = multi_size / sizeof(TypeInput);

    for(int64_t m = 0; m < multis; m++)
    {
        for(int64_t b = 0; b < batches; b++)
        {
            for(int64_t output_y = 0; output_y < _cp.output_height; output_y++)
            {
                for(int64_t output_x = 0; output_x < _cp.output_width; output_x++)
                {
                    int64_t output_xy = (output_y * _cp.output_width) + output_x;

                    for(int64_t kernel_y = 0; kernel_y < _cp.kernel_height; kernel_y++)
                    {
                        for(int64_t kernel_x = 0; kernel_x < _cp.kernel_width; kernel_x++)
                        {
                            // Map the (output, kernel) coordinates back to input coordinates,
                            // accounting for stride and padding.
                            int64_t input_x   = (output_x * _cp.output_stride_w) + kernel_x - _cp.padding_left;
                            int64_t input_y   = (output_y * _cp.output_stride_h) + kernel_y - _cp.padding_top;
                            int64_t kernel_xy = (kernel_y * _cp.kernel_width) + kernel_x;
                            int64_t input_xy  = (input_y * _cp.input_width) + input_x;

                            if(input_x < 0 || input_x >= _cp.input_width || input_y < 0 || input_y >= _cp.input_height)
                            {
                                // Out-of-bounds kernel point: point at the shared padding values.
                                _indirect_buf.get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = _indirect_pad.data();
                            }
                            else
                            {
                                _indirect_buf.get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
                                    A_ptr + (m * multi_stride_A + b * batch_stride_A + input_xy * stride_A);
                            }
                        }
                    }
                }
            }
        }
    }
}
| 285 | |
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info)
{
    ARM_COMPUTE_ERROR_ON(!(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect));

    // Quantized inputs pad with the zero-point offset rather than numeric zero.
    float zeropad = 0.f;
    if(is_data_type_quantized(a->data_type()))
    {
        zeropad = a->quantization_info().uniform().offset;
    }

    const int64_t input_width    = static_cast<int64_t>(a->tensor_shape()[1]);
    const int64_t input_height   = static_cast<int64_t>(a->tensor_shape()[2]);
    const int64_t input_channels = static_cast<int64_t>(a->tensor_shape()[0]);
    const int64_t kernel_width   = static_cast<int64_t>(b->tensor_shape()[2]);
    const int64_t kernel_height  = static_cast<int64_t>(b->tensor_shape()[3]);
    const int64_t output_width   = static_cast<int64_t>(d->tensor_shape()[1]);
    const int64_t output_height  = static_cast<int64_t>(d->tensor_shape()[2]);

    _cp = { input_width, input_height, input_channels, kernel_width, kernel_height, output_width, output_height,
            info.ps_info.stride().first, info.ps_info.stride().second, info.padding_top, info.padding_left, zeropad
          };

    if(info.method == AsmConvMethod::Conv)
    {
        _gemm_kernel_asm->set_convolution_parameters(_cp);
    }

    if(info.method == AsmConvMethod::Indirect)
    {
        const unsigned int multis    = 1;
        const unsigned int batches   = a->tensor_shape().total_size_upper(3);
        const unsigned int kernel_hw = _cp.kernel_width * _cp.kernel_height;
        const unsigned int output_hw = _cp.output_width * _cp.output_height;

        // Sizes/strides of the pointer table, expressed in TypeInput* elements.
        using TypeInputPtr        = TypeInput *;
        const int    batch_size   = kernel_hw * output_hw * sizeof(TypeInputPtr);
        const size_t batch_stride = batch_size / sizeof(TypeInputPtr);
        const int    multi_size   = batch_size * batches;
        const size_t multi_stride = multi_size / sizeof(TypeInputPtr);

        // _indirect_buf holds one input pointer per (multi, batch, kernel point, output point);
        // _indirect_arg holds one row pointer into _indirect_buf per (multi, batch, kernel point).
        // Both use malloc + free_delete; ownership stays with this object.
        _indirect_buf = std::unique_ptr<const TypeInput *, free_delete>(reinterpret_cast<const TypeInput **>(malloc(multi_size * multis)));
        _indirect_arg = std::unique_ptr<const TypeInput *const *, free_delete>(reinterpret_cast<const TypeInput *const **>(malloc(sizeof(TypeInput **) * kernel_hw * multis * batches)));
        _indirect_pad = std::vector<TypeInput>(_cp.input_channels, TypeInput(zeropad));

        // Set indirect argument
        int64_t pos = 0;
        for(int64_t m = 0; m < multis; m++)
        {
            for(int64_t b = 0; b < batches; b++)
            {
                for(int64_t kernel_xy = 0; kernel_xy < kernel_hw; kernel_xy++)
                {
                    (_indirect_arg.get())[pos++] = _indirect_buf.get() + m * multi_stride + b * batch_stride + kernel_xy * output_hw;
                }
            }
        }

        _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.get());
    }
}
| 347 | |
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
                                                             arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
                                                             const OutputStage &os)
{
    ARM_COMPUTE_UNUSED(c);

    _is_b_constant = b->are_values_constant();
    _is_c_constant = c ? c->are_values_constant() : true;

    // Ask arm_gemm for a matching assembly implementation.
    _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput, OutputStage>(args, os);
    if(_gemm_kernel_asm == nullptr)
    {
        //configuration not supported: Leave function unconfigured:
        return;
    }

    arm_gemm::GemmConfig gemm_cfg = _gemm_kernel_asm->get_config();

    // arm_compute wrapper for the Gemm object (see above)
    auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>();
    ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
    acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
    // Register the temporary workspace requirement (4k-aligned) for the assembly kernel.
    const size_t       workspace_size = _gemm_kernel_asm->get_working_size();
    const unsigned int alignment      = 4096;
    _workspace_info            = TensorInfo(TensorShape(workspace_size), 1, DataType::U8);
    _aux_mem[AsmGemmWorkspace] = MemoryInfo(offset_int_vec(AsmGemmWorkspace), MemoryLifetime::Temporary, workspace_size, alignment);

    //if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and
    //the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
    {
        // Don't spawn more threads than the kernel window can be split into.
        const unsigned int window_size = _gemm_kernel_asm->get_window_size().total_size();
        if(window_size < static_cast<unsigned int>(args._maxthreads))
        {
            _gemm_kernel_asm->set_nthreads(window_size);
        }
    }

    _optimised_kernel = std::move(acl_gemm_wrapper);
    _gemm_info        = gemm_info;
    // Check for pre-transposed support
    if(_gemm_kernel_asm->B_pretranspose_required())
    {
        // Forcing 128-byte alignment (required by 32-bit kernels)
        const unsigned int alignment           = 128;
        const size_t       B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
        _pretranspose_info                     = TensorInfo(TensorShape(B_pretranspose_size), 1, DataType::U8);
        // The pretransposed B is reused across runs, hence the persistent lifetime.
        _aux_mem[Pretranspose]   = MemoryInfo(offset_int_vec(Pretranspose), MemoryLifetime::Persistent, B_pretranspose_size, alignment);
        _B_pretranspose_required = true;
    }

    // Handle indirect GEMM convolution
    if(gemm_info.method == AsmConvMethod::Conv || gemm_info.method == AsmConvMethod::Indirect)
    {
        configure_indirect(a, b, d, gemm_info);
    }
}
| 405 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 406 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 407 | void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors) |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 408 | { |
| 409 | if(!_is_prepared) |
| 410 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 411 | auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1); |
| 412 | auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2); |
| 413 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 414 | // Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C. |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 415 | if(c && c->info()->data_type() == DataType::S32) |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 416 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 417 | _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 418 | } |
| 419 | |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 420 | // Pretranspose B if required |
| 421 | if(_gemm_kernel_asm->B_pretranspose_required()) |
| 422 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 423 | const int ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput); |
| 424 | const auto in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes()); |
| 425 | const int multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput); |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 426 | |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 427 | CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false); |
| 428 | ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr); |
| 429 | _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), in1_ptr, ldb, multi_stride_b); |
Georgios Pinitas | fa1db17 | 2021-08-12 06:28:09 +0100 | [diff] [blame] | 430 | |
| 431 | b->mark_as_unused(); |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 432 | } |
| 433 | |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 434 | if(_gemm_info.method == AsmConvMethod::Indirect) |
| 435 | { |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 436 | prepare_indirect_buffer(tensors); |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 437 | } |
| 438 | |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 439 | _is_prepared = true; |
| 440 | } |
| 441 | } |
| 442 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 443 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 444 | bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const |
Anthony Barbier | 71d9b57 | 2018-07-06 17:05:59 +0100 | [diff] [blame] | 445 | { |
| 446 | return _optimised_kernel != nullptr; |
| 447 | } |
| 448 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 449 | template <typename TypeInput, typename TypeOutput, class OutputStage> |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 450 | experimental::MemoryRequirements Fallback<TypeInput, TypeOutput, OutputStage>::workspace() const |
| 451 | { |
| 452 | return _aux_mem; |
| 453 | } |
| 454 | |
// Executes the assembly GEMM: extracts operand pointers/strides from the
// tensor pack, handles non-constant weight/bias re-preparation, sizes the
// thread count to the kernel window, then schedules the optimised kernel.
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
{
    // Operands: A = LHS, B = RHS/weights, C = bias (optional), D = destination.
    auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
    auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
    auto d = tensors.get_tensor(TensorType::ACL_DST);

    // Leading dimensions in elements, converted from byte strides.
    int       lda = a->info()->strides_in_bytes().y() / sizeof(TypeInput);
    int       ldb = 0;
    const int ldd = d->info()->strides_in_bytes().y() / sizeof(TypeOutput);

    // When a tensor is (re)interpreted as 3D, the batch dimension shifts up by one.
    const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d != 0 ? 3 : 2;
    const size_t a_multi_idx = a_batch_idx + 1;
    const size_t d_batch_idx = _gemm_info.depth_output_gemm3d != 0 ? 3 : 2;
    const size_t d_multi_idx = d_batch_idx + 1;

    int       batch_stride_a = a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput);
    const int batch_stride_d = d->info()->strides_in_bytes()[d_batch_idx] / sizeof(TypeOutput);

    int       multi_stride_a = a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput);
    int       multi_stride_b = 0;
    const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / sizeof(TypeOutput);

    auto             in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes());
    const TypeInput *in1_ptr = nullptr;
    auto             out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());

    // Check if B is pre-tranposed and de-reference if not: B's pointer/strides
    // are only handed to the kernel when it consumes B in its original layout.
    if(!_gemm_kernel_asm->B_is_pretransposed())
    {
        ldb            = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
        multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
        in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
    }

    // If necessary, run pretranspose every time if either weights or biases are non-constant.
    // (Constant operands are handled once in prepare(); non-constant ones must be refreshed here.)
    if((b && !_is_b_constant) || (c && !_is_c_constant && c->info()->data_type() == DataType::S32))
    {
        // An S32 C tensor is a quantized bias consumed by the assembly kernel directly.
        if(c && c->info()->data_type() == DataType::S32)
        {
            _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
        }

        // Pretranspose B if required
        if(_B_pretranspose_required)
        {
            const int  ldb            = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
            const auto b_ptr          = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
            const int  multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);

            // The pretranspose scratch tensor is supplied by the caller via the pack.
            CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
            ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);

            if(_is_b_constant)
            {
                // B itself is constant: only the (non-constant) bias needs requantizing.
                _gemm_kernel_asm->requantize_bias(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
            }
            else
            {
                _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
            }
        }
    }

    const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, d->info()->data_type());

    // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
    CpuAuxTensorHandler workspace(offset_int_vec(AsmGemmWorkspace), _workspace_info, tensors, false);
    if(workspace.get()->buffer() != nullptr)
    {
        _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(workspace.get()->buffer()));
        const unsigned int split_dim   = scheduling_hint.split_dimension();
        const unsigned int window_size = _gemm_kernel_asm->get_window_size().total_size();
        unsigned int       num_threads = NEScheduler::get().num_threads();
        // Never ask for more threads than there are units of work.
        if(window_size < num_threads)
        {
            num_threads = window_size;
        }
        if(split_dim != IScheduler::split_dimensions_all)
        {
            // Make sure the kernel does not expect more threads than we can actually spawn
            const unsigned int num_iterations = _optimised_kernel.get()->window().num_iterations(split_dim);
            num_threads                       = std::min(num_iterations, num_threads);
        }
        _gemm_kernel_asm->set_nthreads(num_threads);
    }

    // Prepare assembly kernel (one-shot work: constant-B pretranspose, indirect buffers).
    prepare(tensors);

    // Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C.
    // Note: an S32 C tensor was already consumed above as a quantized bias.
    TypeOutput *bias = nullptr;
    if(c && c->info()->data_type() != DataType::S32)
    {
        bias = reinterpret_cast<TypeOutput *>(c->buffer() + c->info()->offset_first_element_in_bytes());
    }

    // For the indirect convolution method the A operand is addressed through the
    // indirect buffer set up in prepare(), so direct A pointers/strides are cleared.
    if(_gemm_info.method == AsmConvMethod::Indirect)
    {
        in0_ptr        = nullptr;
        lda            = 0;
        batch_stride_a = 0;
        multi_stride_a = 0;
    }

    // Set gemm parameters
    _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a,
                                 in1_ptr, ldb, multi_stride_b,
                                 out_ptr, ldd, batch_stride_d, multi_stride_d,
                                 bias, 0);
    // Schedule
    NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint);
}
| 569 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 570 | template <typename TypeInput, typename TypeOutput> |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 571 | void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, |
| 572 | const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, |
| 573 | arm_gemm::Activation activation, const AsmGemmInfo &info) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 574 | { |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 575 | Params p = extract_parameters(a, b, d, info); |
| 576 | const CPUInfo &ci = NEScheduler::get().cpu_info(); |
| 577 | unsigned int num_threads = NEScheduler::get().num_threads(); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 578 | |
Georgios Pinitas | 4ee8b15 | 2021-07-16 16:16:43 +0100 | [diff] [blame] | 579 | arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fast_mode); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 580 | |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 581 | // Create arm_gemm fallback |
Georgios Pinitas | 40f51a6 | 2020-11-21 03:04:18 +0000 | [diff] [blame] | 582 | auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>(); |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 583 | fallback->configure(a, b, c, d, args, info); |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 584 | arm_gemm = std::move(fallback); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 585 | } |
| 586 | |
| 587 | template <typename TypeInput, typename TypeOutput> |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 588 | void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, |
| 589 | const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, |
| 590 | arm_gemm::Activation activation, const AsmGemmInfo &info) |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 591 | { |
Michele Di Giorgio | 6ad60af | 2020-06-09 14:52:15 +0100 | [diff] [blame] | 592 | ARM_COMPUTE_UNUSED(activation); |
Georgios Pinitas | 4ee8b15 | 2021-07-16 16:16:43 +0100 | [diff] [blame] | 593 | Params p = extract_parameters(a, b, d, info); |
| 594 | const CPUInfo &ci = NEScheduler::get().cpu_info(); |
| 595 | const unsigned int num_threads = NEScheduler::get().num_threads(); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 596 | |
Georgios Pinitas | 4ee8b15 | 2021-07-16 16:16:43 +0100 | [diff] [blame] | 597 | arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fast_mode); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 598 | |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 599 | // Create arm_gemm fallback |
Georgios Pinitas | 40f51a6 | 2020-11-21 03:04:18 +0000 | [diff] [blame] | 600 | auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>(); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 601 | |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 602 | // Configure requantization info |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 603 | const int32_t negation = info.negated_offsets ? 1 : -1; |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 604 | const int32_t a_offset = -a->quantization_info().uniform().offset * negation; |
| 605 | const int32_t b_offset = -b->quantization_info().uniform().offset * negation; |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 606 | const GEMMLowpOutputStageInfo os_info = info.output_stage; |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 607 | |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 608 | arm_gemm::Requantize32 gemm_requant_info{}; |
| 609 | if(os_info.gemmlowp_shifts.size() > 1) |
| 610 | { |
| 611 | const auto requantize_data = fallback->set_requantize_data(os_info.gemmlowp_shifts, os_info.gemmlowp_multipliers); |
| 612 | gemm_requant_info = arm_gemm::Requantize32(nullptr, 0, |
| 613 | a_offset, b_offset, os_info.gemmlowp_offset, |
morgolock | 0bc80da | 2020-08-10 16:44:18 +0100 | [diff] [blame] | 614 | (std::get<0>(requantize_data)) ? std::get<1>(requantize_data) : nullptr, |
| 615 | std::get<2>(requantize_data), |
| 616 | std::get<3>(requantize_data), |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 617 | os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); |
| 618 | } |
| 619 | else |
| 620 | { |
| 621 | gemm_requant_info = arm_gemm::Requantize32(nullptr, 0, |
| 622 | a_offset, b_offset, os_info.gemmlowp_offset, |
| 623 | -os_info.gemmlowp_shift, os_info.gemmlowp_multiplier, |
| 624 | os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound); |
| 625 | } |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 626 | |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 627 | // Configure fallback |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 628 | fallback->configure(a, b, c, d, args, info, gemm_requant_info); |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 629 | arm_gemm = std::move(fallback); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 630 | } |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 631 | } //namespace |
| 632 | |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 633 | CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch() |
| 634 | : _arm_gemm(nullptr) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 635 | { |
| 636 | } |
| 637 | |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 638 | Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 639 | { |
Georgios Pinitas | c0b6f76 | 2020-11-02 01:37:17 +0000 | [diff] [blame] | 640 | ARM_COMPUTE_UNUSED(c, info); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 641 | ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d); |
| 642 | ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a); |
Georgios Pinitas | c7b183a | 2020-03-06 18:12:09 +0000 | [diff] [blame] | 643 | ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a); |
Georgios Pinitas | 0f954eb | 2020-06-23 17:28:38 +0100 | [diff] [blame] | 644 | |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 645 | #ifndef __aarch64__ |
Michele Di Giorgio | 5255672 | 2019-12-23 16:35:12 +0000 | [diff] [blame] | 646 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->element_size() == 1, "8bit integer types only supported for aarch64"); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 647 | #endif /* __aarch64__ */ |
Georgios Pinitas | dbdea0d | 2019-10-16 19:21:40 +0100 | [diff] [blame] | 648 | ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S8, |
Georgios Pinitas | c7b183a | 2020-03-06 18:12:09 +0000 | [diff] [blame] | 649 | DataType::BFLOAT16, DataType::F16, DataType::F32); |
Georgios Pinitas | dbdea0d | 2019-10-16 19:21:40 +0100 | [diff] [blame] | 650 | ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(b, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::S8, |
Georgios Pinitas | c7b183a | 2020-03-06 18:12:09 +0000 | [diff] [blame] | 651 | DataType::BFLOAT16, DataType::F16, DataType::F32); |
Georgios Pinitas | dbdea0d | 2019-10-16 19:21:40 +0100 | [diff] [blame] | 652 | if(is_data_type_quantized_per_channel(b->data_type())) |
| 653 | { |
| 654 | ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::QASYMM8_SIGNED, DataType::S8); |
| 655 | } |
| 656 | else |
| 657 | { |
| 658 | ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b); |
| 659 | } |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 660 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F32 && d->data_type() != DataType::F32, "Only F32 output supported for F32 input"); |
| 661 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16, "Only F16 output supported for F16 input"); |
Georgios Pinitas | c7b183a | 2020-03-06 18:12:09 +0000 | [diff] [blame] | 662 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 && d->data_type() != DataType::F32, "Only F32 output supported for BFLOAT16 input"); |
Anthony Barbier | 9036749 | 2018-08-01 13:56:08 +0100 | [diff] [blame] | 663 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input"); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 664 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input"); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 665 | ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input"); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 666 | return Status{}; |
| 667 | } |
| 668 | |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 669 | bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation) |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 670 | { |
Michele Di Giorgio | d02d5ed | 2021-01-22 09:47:04 +0000 | [diff] [blame] | 671 | arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(activation); |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 672 | return act.type != arm_gemm::Activation::Type::None; |
| 673 | } |
| 674 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 675 | void CpuGemmAssemblyDispatch::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, const AsmGemmInfo &info) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 676 | { |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 677 | ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); |
Michele Di Giorgio | d02d5ed | 2021-01-22 09:47:04 +0000 | [diff] [blame] | 678 | arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(info.activation_info); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 679 | |
| 680 | //If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured() |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 681 | if(!CpuGemmAssemblyDispatch::validate(a, b, c, d, info)) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 682 | { |
| 683 | return; |
| 684 | } |
| 685 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 686 | switch(a->data_type()) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 687 | { |
| 688 | case DataType::F32: |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 689 | create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 690 | break; |
| 691 | #ifdef __aarch64__ |
| 692 | case DataType::U8: |
| 693 | case DataType::QASYMM8: |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 694 | if(d->data_type() == DataType::S32) |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 695 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 696 | create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 697 | } |
| 698 | else |
| 699 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 700 | create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info); |
Georgios Pinitas | cfa2bba | 2019-06-27 17:00:52 +0100 | [diff] [blame] | 701 | } |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 702 | break; |
| 703 | case DataType::S8: |
Georgios Pinitas | dbdea0d | 2019-10-16 19:21:40 +0100 | [diff] [blame] | 704 | case DataType::QASYMM8_SIGNED: |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 705 | if(d->data_type() == DataType::S32) |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 706 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 707 | create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 708 | } |
| 709 | else |
| 710 | { |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 711 | create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info); |
Michalis Spyrou | 71ac903 | 2019-11-14 14:31:44 +0000 | [diff] [blame] | 712 | } |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 713 | break; |
| 714 | #endif /* __aarch64__ */ |
Georgios Pinitas | c7b183a | 2020-03-06 18:12:09 +0000 | [diff] [blame] | 715 | #if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) |
| 716 | case DataType::BFLOAT16: |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 717 | create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info); |
Georgios Pinitas | c7b183a | 2020-03-06 18:12:09 +0000 | [diff] [blame] | 718 | break; |
| 719 | #endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 720 | #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC |
| 721 | case DataType::F16: |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 722 | create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 723 | break; |
| 724 | #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ |
| 725 | default: |
| 726 | break; |
| 727 | } |
| 728 | } |
| 729 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 730 | void CpuGemmAssemblyDispatch::prepare(ITensorPack &tensors) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 731 | { |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 732 | ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 733 | _arm_gemm->prepare(tensors); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 734 | } |
| 735 | |
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 736 | bool CpuGemmAssemblyDispatch::is_configured() const |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 737 | { |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 738 | return _arm_gemm != nullptr && _arm_gemm->is_configured(); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 739 | } |
| 740 | |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 741 | void CpuGemmAssemblyDispatch::run(ITensorPack &tensors) |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 742 | { |
Georgios Pinitas | 48b3ef8 | 2019-10-14 19:03:09 +0100 | [diff] [blame] | 743 | ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr); |
Sang-Hoon Park | d89e2fa | 2021-05-17 17:04:50 +0100 | [diff] [blame] | 744 | _arm_gemm->run(tensors); |
Anthony Barbier | eaefd00 | 2018-07-20 17:49:35 +0100 | [diff] [blame] | 745 | } |
Michele Di Giorgio | d7316eb | 2021-06-16 11:14:41 +0100 | [diff] [blame] | 746 | |
// Exposes the auxiliary memory requirements of the selected fallback so the
// caller can allocate workspace/pretranspose buffers ahead of run().
experimental::MemoryRequirements CpuGemmAssemblyDispatch::workspace() const
{
    ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
    return _arm_gemm->workspace();
}
Sang-Hoon Park | 4f7693d | 2021-05-12 13:59:10 +0100 | [diff] [blame] | 752 | } // namespace cpu |
| 753 | } // namespace arm_compute |