blob: 10fee472f4b83ba95ca5297eb158e5151778f7c5 [file] [log] [blame]
/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_gemm.hpp"
25#include "gemm_common.hpp"
26#include "gemm_hybrid.hpp"
27#include "gemm_implementation.hpp"
28#include "gemm_interleaved.hpp"
29#include "gemm_native.hpp"
30#include "gemv_batched.hpp"
31#include "gemv_native_transposed.hpp"
32#include "gemv_pretransposed.hpp"
33
34#include "kernels/a64_interleaved_bf16fp32_dot_12x8.hpp"
35#include "kernels/a64_interleaved_bf16fp32_mmla_12x8.hpp"
36#include "kernels/a64_sgemm_12x8.hpp"
37#include "kernels/a32_sgemm_8x6.hpp"
38#include "kernels/sve_interleaved_bf16fp32_dot_3VLx8.hpp"
39#include "kernels/sve_interleaved_bf16fp32_mmla_3VLx8.hpp"
40#include "kernels/sve_native_bf16fp32_dot_4VLx4.hpp"
41#include "kernels/sve_hybrid_bf16fp32_dot_4VLx4.hpp"
42#include "kernels/sve_hybrid_bf16fp32_mmla_4VLx4.hpp"
43#include "kernels/sve_hybrid_bf16fp32_mmla_6VLx2.hpp"
44#include "kernels/sve_hybrid_bf16fp32_mmla_8VLx2.hpp"
45
46#include "bfloat.hpp"
47
48namespace arm_gemm {
49
50
51static const GemmImplementation<bfloat16, float> gemm_bf16_methods[] =
52{
53#ifdef V8P6_BF
54# ifdef __ARM_FEATURE_SVE
55{
56 GemmMethod::GEMM_HYBRID,
57 "hybrid_bf16fp32_mmla_6VLx2",
58 [](const GemmArgs &args) { return (args._Ksize>=8 && !args._trA && args._pretransposed_hint); },
59 [](const GemmArgs &args) { return ((args._Msize <= 4) && (args._Nsize <= hybrid_bf16fp32_mmla_6VLx2::out_width())); },
60 [](const GemmArgs &args) { return new GemmHybrid<hybrid_bf16fp32_mmla_6VLx2, bfloat16, float>(args); }
61},
62{
63 GemmMethod::GEMM_HYBRID,
64 "hybrid_bf16fp32_mmla_8VLx2",
65 [](const GemmArgs &args) { return (args._Ksize>=8 && !args._trA && args._pretransposed_hint); },
66 [](const GemmArgs &args) { return (args._Msize <= 4); },
67 [](const GemmArgs &args) { return new GemmHybrid<hybrid_bf16fp32_mmla_8VLx2, bfloat16, float>(args); }
68},
69{
70 GemmMethod::GEMM_HYBRID,
71 "hybrid_bf16fp32_mmla_4VLx4",
72 [](const GemmArgs &args) { return (args._Ksize>=8 && !args._trA && args._pretransposed_hint); },
73 [](const GemmArgs &args) { return ((args._Ksize <= 128) && (args._Nsize <= 128)); },
74 [](const GemmArgs &args) { return new GemmHybrid<hybrid_bf16fp32_mmla_4VLx4, bfloat16, float>(args); }
75},
76{
77 GemmMethod::GEMM_HYBRID,
78 "hybrid_bf16fp32_dot_4VLx4",
79 [](const GemmArgs &args) { return (args._Ksize>=8 && !args._trA && args._pretransposed_hint); },
80 [](const GemmArgs &args) { return ((args._Ksize <= 128) && (args._Nsize <= 128)); },
81 [](const GemmArgs &args) { return new GemmHybrid<hybrid_bf16fp32_dot_4VLx4, bfloat16, float>(args); }
82},
83{ // gemm_bf16_native
84 GemmMethod::GEMM_NATIVE,
85 "native_bf16fp32_dot_4VLx4",
86 [](const GemmArgs &args) { return (args._Ksize>=8 && !args._trA && !args._trB); },
87 [](const GemmArgs &args) { return ((args._Ksize <= 128) && (args._Nsize <= 128)); },
88 [](const GemmArgs &args) { return new GemmNative<native_bf16fp32_dot_4VLx4, bfloat16, float>(args); }
89},
90{ // gemm_bf16_interleaved
91 GemmMethod::GEMM_INTERLEAVED,
92 "interleaved_bf16fp32_mmla_3VLx8",
93 [](const GemmArgs &args) { return (args._Ksize>4); },
94 nullptr,
95 [](const GemmArgs &args) { return new GemmInterleaved<interleaved_bf16fp32_mmla_3VLx8, bfloat16, float>(args); }
96},
97{ // gemm_bf16_interleaved
98 GemmMethod::GEMM_INTERLEAVED,
99 "interleaved_bf16fp32_dot_3VLx8",
100 [](const GemmArgs &args) { return (args._Ksize>2); },
101 nullptr,
102 [](const GemmArgs &args) { return new GemmInterleaved<interleaved_bf16fp32_dot_3VLx8, bfloat16, float>(args); }
103},
104# endif // SVE
105{ // gemm_bf16_interleaved
106 GemmMethod::GEMM_INTERLEAVED,
107 "interleaved_bf16fp32_mmla_12x8",
108 [](const GemmArgs &args) { return (args._Ksize>4); },
109 nullptr,
110 [](const GemmArgs &args) { return new GemmInterleaved<interleaved_bf16fp32_mmla_12x8, bfloat16, float>(args); }
111},
112{ // gemm_bf16_interleaved
113 GemmMethod::GEMM_INTERLEAVED,
114 "interleaved_bf16fp32_dot_12x8",
115 [](const GemmArgs &args) { return (args._Ksize>2); },
116 nullptr,
117 [](const GemmArgs &args) { return new GemmInterleaved<interleaved_bf16fp32_dot_12x8, bfloat16, float>(args); }
118},
119#endif // V8P6_BF
120#ifdef __aarch64__
121{
122 GemmMethod::GEMM_INTERLEAVED,
123 "sgemm_12x8",
124 nullptr,
125 nullptr,
126 [](const GemmArgs &args) { return new GemmInterleaved<sgemm_12x8, bfloat16, float>(args); }
127},
128#elif defined(__arm__)
129{
130 GemmMethod::GEMM_INTERLEAVED,
131 "sgemm_8x6",
132 nullptr,
133 nullptr,
134 [](const GemmArgs &args) { return new GemmInterleaved<sgemm_8x6, bfloat16, float>(args); }
135},
136#else
137# error "Unknown Architecture"
138#endif
139{
140 GemmMethod::DEFAULT,
141 "",
142 nullptr,
143 nullptr,
144 nullptr
145}
146};
147
148template<>
149const GemmImplementation<bfloat16, float> *gemm_implementation_list<bfloat16, float>() {
150 return gemm_bf16_methods;
151}
152
153/* Explicitly instantiate the external functions for these types. */
154template UniqueGemmCommon<bfloat16, float> gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
155template KernelDescription get_gemm_method<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
156template std::vector<KernelDescription> get_compatible_kernels<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
157
158} // namespace arm_gemm