blob: d952140959e366c750bc662b7a282f9528f1c627 [file] [log] [blame]
David Manselle39334c2018-07-06 17:53:35 +01001/*
Georgios Pinitas7cd26d42019-01-09 18:35:17 +00002 * Copyright (c) 2018-2019 ARM Limited.
David Manselle39334c2018-07-06 17:53:35 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
#include <arm_gemm.hpp>

#include <cstring>
#include <functional>
#include <vector>
David Manselle39334c2018-07-06 17:53:35 +010028
29namespace arm_gemm {
30
31template<typename Top, typename Tret>
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000032struct GemmImplementation {
33 const GemmMethod method;
34 const char * name;
35 std::function<bool(const GemmArgs<Tret> &)> is_supported;
36 std::function<bool(const GemmArgs<Tret> &)> is_recommended;
37 std::function<GemmCommon<Top, Tret> *(const GemmArgs<Tret> &)> instantiate;
David Manselle39334c2018-07-06 17:53:35 +010038};
39
40/* "Master" function implemented for each valid combination of types.
41 * Returns a list of GEMM implementation descriptors for processing by the
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000042 * other functions, terminated by an implementation with
43 * method==GemmMethod::DEFAULT. */
David Manselle39334c2018-07-06 17:53:35 +010044template<typename Top, typename Tret>
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000045const GemmImplementation<Top, Tret> *gemm_implementation_list();
David Manselle39334c2018-07-06 17:53:35 +010046
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000047/*
48 * Select a GEMM implementation for the given arguments.
49 *
50 * The logic here returns the first method on the list which supports the
51 * requested problem parameters, matches the provided filters (method and/or
52 * name string match) and recommends itself.
53 *
54 * If there is no such method, it will return the first method which
55 * supports the requested parameters and passes the filters, regardless of
56 * recommendation.
57 *
58 * If no method supports the requested parameters and passes the filters,
59 * this function returns false and doesn't touch the provided pointer
60 * reference.
61 */
David Manselle39334c2018-07-06 17:53:35 +010062template<typename Top, typename Tret>
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000063bool find_implementation(const GemmArgs<Tret> &args, const GemmImplementation<Top, Tret> * &impl) {
David Manselle39334c2018-07-06 17:53:35 +010064 auto gemms = gemm_implementation_list<Top, Tret>();
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000065 const GemmConfig *cfg = args._cfg;
David Manselle39334c2018-07-06 17:53:35 +010066
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000067 const GemmImplementation<Top, Tret> *saved_impl = nullptr;
68
69 for (auto i = gemms; i->method != GemmMethod::DEFAULT; i++) {
David Manselle39334c2018-07-06 17:53:35 +010070 /* Skip if this implementation doesn't support these args. */
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000071 if (i->is_supported != nullptr && !i->is_supported(args)) {
David Manselle39334c2018-07-06 17:53:35 +010072 continue;
73 }
74
75 /* Skip if a specific method is requested and this is a different one. */
76 if (cfg && cfg->method != GemmMethod::DEFAULT && i->method != cfg->method) {
77 continue;
78 }
79
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000080 /* Skip if a filter is to be applied and it doesn't match. */
81 if (cfg && cfg->filter != "" && !strstr(i->name, cfg->filter.c_str())) {
David Manselle39334c2018-07-06 17:53:35 +010082 continue;
83 }
84
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000085 /* At this point, if we don't have a saved implementation, save this
86 * one. This is so that we always return something if a filter
87 * matches, even if it doesn't recommend itself.
88 */
89 if (saved_impl == nullptr) {
90 saved_impl=i;
91 }
92
93 /* Check that this method recommends itself. */
94 if (i->is_recommended != nullptr && !i->is_recommended(args)) {
95 continue;
96 }
97
98 impl=i;
99
100 return true;
David Manselle39334c2018-07-06 17:53:35 +0100101 }
102
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000103 /* We didn't find an option matching the filters that recommended
104 * itself. But if we found something earlier that matched the filters
105 * but wasn't recommended, return it here. */
106 if (saved_impl != nullptr) {
107 impl = saved_impl;
David Manselle39334c2018-07-06 17:53:35 +0100108 return true;
109 }
110
111 return false;
112}
113
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000114template<typename Top, typename Tret>
Georgios Pinitas14613832019-03-01 19:07:11 +0000115std::vector<KernelDescription> get_compatible_kernels(const GemmArgs<Tret> &args) {
116 std::vector<KernelDescription> res;
117
118 /* Find out what the default implementation in so we can set the flag accordingly later. */
119 const GemmImplementation<Top, Tret> *default_impl;
120 find_implementation(args, default_impl);
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000121
122 auto gemms = gemm_implementation_list<Top, Tret>();
123
124 for (auto i = gemms; i->method != GemmMethod::DEFAULT; i++) {
125 /* Check that this implementation supports the presented problem. */
126 if (i->is_supported != nullptr && !i->is_supported(args)) {
127 continue;
128 }
129
Georgios Pinitas14613832019-03-01 19:07:11 +0000130 res.push_back(KernelDescription(i->method, i->name, i==default_impl));
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000131 }
132
133 return res;
134}
135
136template<typename Top, typename Tret>
137UniqueGemmCommon<Top, Tret> gemm(const GemmArgs<Tret> &args) {
138 const GemmImplementation<Top, Tret> *impl;
139
140 if (find_implementation<Top, Tret>(args, impl)) {
141 return UniqueGemmCommon<Top, Tret>(impl->instantiate(args));
142 }
143
144 return UniqueGemmCommon<Top, Tret>(nullptr);
145}
146
147template<typename Top, typename Tret>
148KernelDescription get_gemm_method(const GemmArgs<Tret> &args) {
149 const GemmImplementation<Top, Tret> *impl;
150
151 if (find_implementation<Top, Tret>(args, impl)) {
152 return KernelDescription(impl->method, impl->name);
153 }
154
155 /* This shouldn't happen - there should always be at least one valid implementation. */
156 return KernelDescription();
157}
158
159template<typename Top, typename Tret>
160bool method_is_compatible(GemmMethod method, const GemmArgs<Tret> &args) {
161 /* Determine if the method is valid by attempting to obtain an implementation specifying this method. */
162 GemmConfig cfg(method);
163 GemmArgs<Tret> myargs = args;
164
165 myargs._cfg = &cfg;
166
167 const GemmImplementation<Top, Tret> *impl;
168
169 return find_implementation<Top, Tret>(myargs, impl);
170}
171
172} // namespace arm_gemm