blob: c726d7b0aa48284c883ff8b5b6092e6c26fe332a [file] [log] [blame]
David Manselle39334c2018-07-06 17:53:35 +01001/*
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +01002 * Copyright (c) 2018-2020 Arm Limited.
David Manselle39334c2018-07-06 17:53:35 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
#include "arm_gemm.hpp"

#include <cstring>
#include <functional>
#include <vector>
David Manselle39334c2018-07-06 17:53:35 +010028
29namespace arm_gemm {
30
Georgios Pinitascfa2bba2019-06-27 17:00:52 +010031/* Structure describing an implementation. For each supported combination
32 * of types, a static list of these structures is built up to describe the
33 * implementations available.
34 */
35template<typename Top, typename Tret, class OutputStage = Nothing>
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000036struct GemmImplementation {
Georgios Pinitas48b3ef82019-10-14 19:03:09 +010037 const GemmMethod method;
38 const char * name;
39 std::function<bool(const GemmArgs &, const OutputStage &)> is_supported;
40 std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended;
41 std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate;
Georgios Pinitascfa2bba2019-06-27 17:00:52 +010042
Georgios Pinitas48b3ef82019-10-14 19:03:09 +010043 bool do_is_supported(const GemmArgs &args, const OutputStage &os) const {
Georgios Pinitascfa2bba2019-06-27 17:00:52 +010044 if (is_supported != nullptr) {
45 return is_supported(args, os);
46 } else {
47 return true;
48 }
49 }
50
Georgios Pinitas48b3ef82019-10-14 19:03:09 +010051 bool do_is_recommended(const GemmArgs &args, const OutputStage &os) const {
Georgios Pinitascfa2bba2019-06-27 17:00:52 +010052 if (is_recommended != nullptr) {
53 return is_recommended(args, os);
54 } else {
55 return true;
56 }
57 }
58
Georgios Pinitas48b3ef82019-10-14 19:03:09 +010059 GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const OutputStage &os) const {
Georgios Pinitascfa2bba2019-06-27 17:00:52 +010060 return instantiate(args, os);
61 }
62};
63
64/* Slightly different version of above for straightforward GEMMs with no
65 * output stage, so the std::functions there don't have to deal with the
66 * unnecessary second argument. */
67template<typename Top, typename Tret>
68struct GemmImplementation<Top, Tret, Nothing> {
Georgios Pinitas48b3ef82019-10-14 19:03:09 +010069 const GemmMethod method;
70 const char * name;
71 std::function<bool(const GemmArgs &)> is_supported;
72 std::function<bool(const GemmArgs &)> is_recommended;
73 std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate;
Georgios Pinitascfa2bba2019-06-27 17:00:52 +010074
Georgios Pinitas48b3ef82019-10-14 19:03:09 +010075 bool do_is_supported(const GemmArgs &args, const Nothing &) const {
Georgios Pinitascfa2bba2019-06-27 17:00:52 +010076 if (is_supported != nullptr) {
77 return is_supported(args);
78 } else {
79 return true;
80 }
81 }
82
Georgios Pinitas48b3ef82019-10-14 19:03:09 +010083 bool do_is_recommended(const GemmArgs &args, const Nothing &) const {
Georgios Pinitascfa2bba2019-06-27 17:00:52 +010084 if (is_recommended != nullptr) {
85 return is_recommended(args);
86 } else {
87 return true;
88 }
89 }
90
Georgios Pinitas48b3ef82019-10-14 19:03:09 +010091 GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const Nothing &) const {
Georgios Pinitascfa2bba2019-06-27 17:00:52 +010092 return instantiate(args);
93 }
David Manselle39334c2018-07-06 17:53:35 +010094};
95
96/* "Master" function implemented for each valid combination of types.
97 * Returns a list of GEMM implementation descriptors for processing by the
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000098 * other functions, terminated by an implementation with
99 * method==GemmMethod::DEFAULT. */
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100100template<typename Top, typename Tret, class OutputStage = Nothing>
101const GemmImplementation<Top, Tret, OutputStage> *gemm_implementation_list();
David Manselle39334c2018-07-06 17:53:35 +0100102
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000103/*
104 * Select a GEMM implementation for the given arguments.
105 *
106 * The logic here returns the first method on the list which supports the
107 * requested problem parameters, matches the provided filters (method and/or
108 * name string match) and recommends itself.
109 *
110 * If there is no such method, it will return the first method which
111 * supports the requested parameters and passes the filters, regardless of
112 * recommendation.
113 *
114 * If no method supports the requested parameters and passes the filters,
115 * this function returns false and doesn't touch the provided pointer
116 * reference.
117 */
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100118template<typename Top, typename Tret, class OutputStage>
Georgios Pinitas48b3ef82019-10-14 19:03:09 +0100119bool find_implementation(const GemmArgs &args, const OutputStage &os, const GemmImplementation<Top, Tret, OutputStage> * &impl) {
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100120 auto gemms = gemm_implementation_list<Top, Tret, OutputStage>();
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000121 const GemmConfig *cfg = args._cfg;
David Manselle39334c2018-07-06 17:53:35 +0100122
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100123 const GemmImplementation<Top, Tret, OutputStage> *saved_impl = nullptr;
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000124
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100125 for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) {
David Manselle39334c2018-07-06 17:53:35 +0100126 /* Skip if this implementation doesn't support these args. */
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100127 if (!i->do_is_supported(args, os)) {
David Manselle39334c2018-07-06 17:53:35 +0100128 continue;
129 }
130
131 /* Skip if a specific method is requested and this is a different one. */
132 if (cfg && cfg->method != GemmMethod::DEFAULT && i->method != cfg->method) {
133 continue;
134 }
135
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000136 /* Skip if a filter is to be applied and it doesn't match. */
137 if (cfg && cfg->filter != "" && !strstr(i->name, cfg->filter.c_str())) {
David Manselle39334c2018-07-06 17:53:35 +0100138 continue;
139 }
140
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000141 /* At this point, if we don't have a saved implementation, save this
142 * one. This is so that we always return something if a filter
143 * matches, even if it doesn't recommend itself.
144 */
145 if (saved_impl == nullptr) {
146 saved_impl=i;
147 }
148
149 /* Check that this method recommends itself. */
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100150 if (!i->do_is_recommended(args, os)) {
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000151 continue;
152 }
153
154 impl=i;
155
156 return true;
David Manselle39334c2018-07-06 17:53:35 +0100157 }
158
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000159 /* We didn't find an option matching the filters that recommended
160 * itself. But if we found something earlier that matched the filters
161 * but wasn't recommended, return it here. */
162 if (saved_impl != nullptr) {
163 impl = saved_impl;
David Manselle39334c2018-07-06 17:53:35 +0100164 return true;
165 }
166
167 return false;
168}
169
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100170template<typename Top, typename Tret, class OutputStage>
Georgios Pinitas48b3ef82019-10-14 19:03:09 +0100171std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage &os) {
Georgios Pinitas14613832019-03-01 19:07:11 +0000172 std::vector<KernelDescription> res;
173
174 /* Find out what the default implementation in so we can set the flag accordingly later. */
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100175 const GemmImplementation<Top, Tret, OutputStage> *default_impl;
176 find_implementation(args, os, default_impl);
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000177
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100178 auto gemms = gemm_implementation_list<Top, Tret, OutputStage>();
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000179
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100180 for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) {
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000181 /* Check that this implementation supports the presented problem. */
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100182 if (!i->do_is_supported(args, os)) {
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000183 continue;
184 }
185
Georgios Pinitas14613832019-03-01 19:07:11 +0000186 res.push_back(KernelDescription(i->method, i->name, i==default_impl));
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000187 }
188
189 return res;
190}
191
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100192template<typename Top, typename Tret, class OutputStage>
Georgios Pinitas48b3ef82019-10-14 19:03:09 +0100193UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage &os) {
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100194 const GemmImplementation<Top, Tret, OutputStage> *impl;
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000195
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100196 if (find_implementation<Top, Tret, OutputStage>(args, os, impl)) {
197 return UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os));
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000198 }
199
200 return UniqueGemmCommon<Top, Tret>(nullptr);
201}
202
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100203template<typename Top, typename Tret, class OutputStage>
Georgios Pinitas48b3ef82019-10-14 19:03:09 +0100204KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage &os) {
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100205 const GemmImplementation<Top, Tret, OutputStage> *impl;
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000206
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100207 if (find_implementation<Top, Tret>(args, os, impl)) {
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000208 return KernelDescription(impl->method, impl->name);
209 }
210
211 /* This shouldn't happen - there should always be at least one valid implementation. */
212 return KernelDescription();
213}
214
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100215} // namespace arm_gemm