blob: c7adf8e4accf326cd7f40cbdb96ce9e44ef58e36 [file] [log] [blame]
/*
 * Copyright (c) 2017-2020, 2022-2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
Pablo Tello99ef8402018-03-20 16:46:55 +000024
// This can only be built if the target/compiler supports FP16 arguments.
Georgios Pinitas4ee8b152021-07-16 16:16:43 +010026#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC))
Pablo Telloeb82fd22018-02-23 13:43:50 +000027
28#include "arm_gemm.hpp"
29
30#include "gemm_common.hpp"
Georgios Pinitas14613832019-03-01 19:07:11 +000031#include "gemm_hybrid.hpp"
Georgios Pinitasc0b6f762020-11-02 01:37:17 +000032#include "gemm_hybrid_indirect.hpp"
David Manselle39334c2018-07-06 17:53:35 +010033#include "gemm_implementation.hpp"
Pablo Telloeb82fd22018-02-23 13:43:50 +000034#include "gemm_interleaved.hpp"
David Mansell0fa92b82023-10-17 13:33:24 +010035#include "gemv_pretransposed.hpp"
Pablo Telloeb82fd22018-02-23 13:43:50 +000036
Georgios Pinitas14613832019-03-01 19:07:11 +000037#include "kernels/a32_sgemm_8x6.hpp"
Francesco Petrogalli553f6952022-06-30 10:22:01 +000038#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
Francesco.Petrogalli@arm.com5fcf22d2022-04-05 10:31:08 +000039#include "kernels/a64_ffhybrid_fp16_mla_6x32.hpp"
40#include "kernels/a64_ffinterleaved_fp16_mla_8x24.hpp"
Francesco Petrogalli553f6952022-06-30 10:22:01 +000041#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
Georgios Pinitasc0b6f762020-11-02 01:37:17 +000042#include "kernels/a64_hgemm_8x24.hpp"
43#include "kernels/a64_hybrid_fp16_mla_6x32.hpp"
44#include "kernels/a64_sgemm_8x12.hpp"
David Mansellaaa9da12023-03-10 13:48:50 +000045#ifdef ARM_COMPUTE_ENABLE_SME2
David Mansell0fa92b82023-10-17 13:33:24 +010046#include "kernels/sme2_gemv_fp16fp32fp16_dot_16VL.hpp"
David Mansellaaa9da12023-03-10 13:48:50 +000047#include "kernels/sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL.hpp"
48#include "kernels/sme2_interleaved_nomerge_fp16fp32fp16_mopa_2VLx2VL.hpp"
49#include "kernels/sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL.hpp"
50#endif // ARM_COMPUTE_ENABLE_SME2
Francesco Petrogalli553f6952022-06-30 10:22:01 +000051#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
Francesco.Petrogalli@arm.com5fcf22d2022-04-05 10:31:08 +000052#include "kernels/sve_ffhybrid_fp16_mla_6x4VL.hpp"
53#include "kernels/sve_ffinterleaved_fp16_mla_8x3VL.hpp"
Francesco Petrogalli553f6952022-06-30 10:22:01 +000054#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
Georgios Pinitasc0b6f762020-11-02 01:37:17 +000055#include "kernels/sve_hybrid_fp16_mla_6x4VL.hpp"
56#include "kernels/sve_interleaved_fp16_mla_8x3VL.hpp"
Pablo Telloeb82fd22018-02-23 13:43:50 +000057
Anthony Barbier5f707732018-07-03 16:22:02 +010058namespace arm_gemm {
59
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000060static const GemmImplementation<__fp16, __fp16> gemm_fp16_methods[] = {
Georgios Pinitas4ee8b152021-07-16 16:16:43 +010061#ifdef ARM_COMPUTE_ENABLE_SVE
David Mansellaaa9da12023-03-10 13:48:50 +000062#ifdef ARM_COMPUTE_ENABLE_SME2
63{
David Mansell0fa92b82023-10-17 13:33:24 +010064 GemmMethod::GEMM_HYBRID,
65 "sme2_gemv_fp16fp32fp16_dot_16VL",
66 [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; },
67 nullptr,
68 [](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_fp16fp32fp16_dot_16VL, __fp16, __fp16>(args); }
69},
70{
David Mansellaaa9da12023-03-10 13:48:50 +000071 GemmMethod::GEMM_INTERLEAVED,
David Mansell5c767422024-03-15 16:35:13 +000072 "sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL",
73 [](const GemmArgs &args) { return args._ci->has_sme2(); },
74 [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
75 return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
76 [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL, __fp16, __fp16, Nothing, false, false, false, true>(args); }
77},
78{
79 GemmMethod::GEMM_INTERLEAVED,
David Mansellaaa9da12023-03-10 13:48:50 +000080 "sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL",
81 [](const GemmArgs &args) { return args._ci->has_sme2(); },
82 [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
83 return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
84 [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL, __fp16, __fp16, Nothing, false, false, false, true>(args); }
85},
86{
87 GemmMethod::GEMM_INTERLEAVED,
David Mansellaaa9da12023-03-10 13:48:50 +000088 "sme2_interleaved_nomerge_fp16fp32fp16_mopa_2VLx2VL",
89 [](const GemmArgs &args) { return args._ci->has_sme2(); },
90 nullptr,
91 [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_2VLx2VL, __fp16, __fp16, Nothing, false, false, false, true>(args); }
92},
93#endif // ARM_COMPUTE_ENABLE_SME2
Georgios Pinitas4ee8b152021-07-16 16:16:43 +010094GemmImplementation<__fp16, __fp16>::with_estimate(
Georgios Pinitas14613832019-03-01 19:07:11 +000095 GemmMethod::GEMM_HYBRID,
Georgios Pinitasc0b6f762020-11-02 01:37:17 +000096 "sve_hybrid_fp16_mla_6x4VL",
Pablo Marquez Telloa50f1932021-03-08 17:27:05 +000097 [](const GemmArgs &args) { return args._ci->has_sve(); },
Georgios Pinitas4ee8b152021-07-16 16:16:43 +010098 [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp16_mla_6x4VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
99 [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp16_mla_6x4VL, __fp16, __fp16>(args); }
100),
101GemmImplementation<__fp16, __fp16>::with_estimate(
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000102 GemmMethod::GEMM_INTERLEAVED,
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000103 "sve_interleaved_fp16_mla_8x3VL",
Pablo Marquez Telloa50f1932021-03-08 17:27:05 +0000104 [](const GemmArgs &args) { return args._ci->has_sve() && (args._Ksize > 4); },
Georgios Pinitas4ee8b152021-07-16 16:16:43 +0100105 [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000106 [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>(args); }
Georgios Pinitas4ee8b152021-07-16 16:16:43 +0100107),
Francesco Petrogalli553f6952022-06-30 10:22:01 +0000108#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
Francesco.Petrogalli@arm.com5fcf22d2022-04-05 10:31:08 +0000109GemmImplementation<__fp16, __fp16>::with_estimate(
110 GemmMethod::GEMM_INTERLEAVED,
111 "sve_ffinterleaved_fp16_mla_8x3VL",
112 KernelWeightFormat::VL1VL_BL16,
113 [](const GemmArgs &args) { return args._ci->has_sve(); },
114 [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp16_mla_8x3VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
115 [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp16_mla_8x3VL, __fp16, __fp16>(args); }
116),
117GemmImplementation<__fp16, __fp16>::with_estimate(
118 GemmMethod::GEMM_HYBRID,
119 "sve_ffhybrid_fp16_mla_6x4VL",
120 KernelWeightFormat::VL1VL_BL16,
121 [](const GemmArgs &args) { return args._ci->has_sve(); },
122 [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
123 [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>(args); }
124),
Francesco Petrogalli553f6952022-06-30 10:22:01 +0000125#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
Georgios Pinitas4ee8b152021-07-16 16:16:43 +0100126#endif // ARM_COMPUTE_ENABLE_SVE
127#if defined(__aarch64__)
Georgios Pinitas40943df2020-11-17 18:46:40 +0000128GemmImplementation<__fp16, __fp16>::with_estimate(
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000129 GemmMethod::GEMM_HYBRID,
130 "a64_hybrid_fp16_mla_6x32",
cfRod534fdea2020-06-25 18:12:25 +0100131 [](const GemmArgs &args) { return args._ci->has_fp16(); },
Georgios Pinitas4ee8b152021-07-16 16:16:43 +0100132 [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_fp16_mla_6x32, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000133 [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_fp16_mla_6x32, __fp16, __fp16>(args); }
Georgios Pinitas40943df2020-11-17 18:46:40 +0000134),
135GemmImplementation<__fp16, __fp16>::with_estimate(
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000136 GemmMethod::GEMM_INTERLEAVED,
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000137 "a64_hgemm_8x24",
Georgios Pinitas48b3ef82019-10-14 19:03:09 +0100138 [](const GemmArgs &args) { return args._ci->has_fp16(); },
Georgios Pinitas4ee8b152021-07-16 16:16:43 +0100139 [](const GemmArgs &args) { return GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000140 [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>(args); }
Georgios Pinitas40943df2020-11-17 18:46:40 +0000141),
Francesco Petrogalli553f6952022-06-30 10:22:01 +0000142#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
Francesco.Petrogalli@arm.com5fcf22d2022-04-05 10:31:08 +0000143GemmImplementation<__fp16, __fp16>::with_estimate(
144 GemmMethod::GEMM_INTERLEAVED,
145 "a64_ffinterleaved_fp16_mla_8x24",
146 KernelWeightFormat::VL128_BL16,
147 [](const GemmArgs &args) { return args._ci->has_fp16(); },
148 [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp16_mla_8x24, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
149 [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp16_mla_8x24, __fp16, __fp16>(args); }
150),
151GemmImplementation<__fp16, __fp16>::with_estimate(
152 GemmMethod::GEMM_HYBRID,
153 "a64_ffhybrid_fp16_mla_6x32",
154 KernelWeightFormat::VL128_BL16,
155 [](const GemmArgs &args) { return args._ci->has_fp16(); },
156 [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
157 [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>(args); }
158),
Francesco Petrogalli553f6952022-06-30 10:22:01 +0000159#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
Georgios Pinitas14613832019-03-01 19:07:11 +0000160{
161 GemmMethod::GEMM_INTERLEAVED,
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000162 "a64_sgemm_8x12",
Georgios Pinitas14613832019-03-01 19:07:11 +0000163 nullptr,
Georgios Pinitas40943df2020-11-17 18:46:40 +0000164 [](const GemmArgs &args) { return !args._ci->has_fp16(); },
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000165 [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x12, __fp16, __fp16>(args); }
Georgios Pinitas14613832019-03-01 19:07:11 +0000166},
167#elif defined(__arm__)
Georgios Pinitasa41c54b2019-01-30 18:16:43 +0000168{
169 GemmMethod::GEMM_INTERLEAVED,
170 "sgemm_8x6",
Georgios Pinitas14613832019-03-01 19:07:11 +0000171 nullptr,
172 nullptr,
Georgios Pinitas48b3ef82019-10-14 19:03:09 +0100173 [](const GemmArgs &args) { return new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(args); }
Georgios Pinitasa41c54b2019-01-30 18:16:43 +0000174},
Georgios Pinitas14613832019-03-01 19:07:11 +0000175#else // not AArch64 or AArch32
176# error Unknown Architecture
Georgios Pinitasa41c54b2019-01-30 18:16:43 +0000177#endif
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000178{
179 GemmMethod::DEFAULT,
180 "",
181 nullptr,
182 nullptr,
183 nullptr,
184}
David Manselle39334c2018-07-06 17:53:35 +0100185};
186
187template<>
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000188const GemmImplementation<__fp16, __fp16> *gemm_implementation_list<__fp16, __fp16>() {
David Manselle39334c2018-07-06 17:53:35 +0100189 return gemm_fp16_methods;
Pablo Telloeb82fd22018-02-23 13:43:50 +0000190}
191
David Manselle39334c2018-07-06 17:53:35 +0100192/* Explicitly instantiate the external functions for these types. */
Georgios Pinitas48b3ef82019-10-14 19:03:09 +0100193template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
Francesco Petrogalli553f6952022-06-30 10:22:01 +0000194template bool has_opt_gemm<__fp16, __fp16, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
Francesco.Petrogalli@arm.com5fcf22d2022-04-05 10:31:08 +0000195template KernelDescription get_gemm_method<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
Georgios Pinitas48b3ef82019-10-14 19:03:09 +0100196template std::vector<KernelDescription> get_compatible_kernels<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
David Manselle39334c2018-07-06 17:53:35 +0100197
Pablo Telloeb82fd22018-02-23 13:43:50 +0000198} // namespace arm_gemm
199
Georgios Pinitas4ee8b152021-07-16 16:16:43 +0100200#endif // defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC))