/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_gemm.hpp"
25#include "gemm_common.hpp"
26#include "gemm_interleaved.hpp"
27#include "gemm_native.hpp"
David Mansellce8f6052018-05-17 18:51:26 +010028#include "gemv_batched.hpp"
Pablo Telloeb82fd22018-02-23 13:43:50 +000029#include "gemv_native_transposed.hpp"
30#include "gemv_pretransposed.hpp"
31
32#include "kernels/a32_sgemm_8x6.hpp"
33#include "kernels/a64_sgemm_12x8.hpp"
34#include "kernels/a64_sgemm_native_16x4.hpp"
35#include "kernels/a64_sgemv_pretransposed.hpp"
36#include "kernels/a64_sgemv_trans.hpp"
37
38namespace arm_gemm
39{
40template <>
41UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
Michalis Spyroue7e96e02018-04-13 13:44:10 +010042 const unsigned int nbatches, const unsigned int nmulti,
Pablo Telloeb82fd22018-02-23 13:43:50 +000043 const bool trA, const bool trB, const float alpha, const float beta,
David Mansellce8f6052018-05-17 18:51:26 +010044 const int maxthreads, const bool pretransposed_hint) {
45 /* Handle "batched GEMV" */
46 if (M==1 && nbatches>1) {
47 return UniqueGemmCommon<float, float> (new GemvBatched<float, float>(ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
Michalis Spyroue7e96e02018-04-13 13:44:10 +010048 }
Pablo Telloeb82fd22018-02-23 13:43:50 +000049#ifdef __aarch64__
50 /* Cases in priority order */
Michalis Spyroue7e96e02018-04-13 13:44:10 +010051 /* GemvPretransposed: requires M=1, alpha=1, and transposed hint set. nbatches must be 1 or we would have returned above so don't test. */
Pablo Telloeb82fd22018-02-23 13:43:50 +000052 if(M == 1 && alpha == 1.0f && pretransposed_hint)
53 {
Michalis Spyroue7e96e02018-04-13 13:44:10 +010054 return UniqueGemmCommon<float, float>(new GemvPretransposed<sgemv_pretransposed, float, float>(&ci, N, K, nmulti, trB, beta));
Pablo Telloeb82fd22018-02-23 13:43:50 +000055 }
56
Pablo Tello99ef8402018-03-20 16:46:55 +000057 /* GemvNativeTransposed: requires M=1, no trA or trB, doesn't handle alpha */
58 if(M == 1 && alpha == 1.0f && !trA && !trB)
Pablo Telloeb82fd22018-02-23 13:43:50 +000059 {
Michalis Spyroue7e96e02018-04-13 13:44:10 +010060 return UniqueGemmCommon<float, float>(new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, nmulti, beta));
Pablo Telloeb82fd22018-02-23 13:43:50 +000061 }
62
David Mansellce8f6052018-05-17 18:51:26 +010063 /* Native GEMM: requires K at least 4, N a multiple of 16, doesn't
64 * handle alpha or transpose. Use for small N/K, or if the blocked GEMM
65 * won't thread properly. */
66 if ((K >= 4) && ((N % 16) == 0) && alpha==1.0f && !trA && !trB &&
67 ((K <= 128 && N <= 128) || (nmulti > 1 && (M/maxthreads) < 8))) {
68 return UniqueGemmCommon<float, float> (new GemmNative<sgemm_native_16x4, float, float>(&ci, M, N, K, nbatches, nmulti, beta));
Pablo Telloeb82fd22018-02-23 13:43:50 +000069 }
70
71 /* Blocked GEMM, handles all cases. */
Michalis Spyroue7e96e02018-04-13 13:44:10 +010072 return UniqueGemmCommon<float, float>(new GemmInterleaved<sgemm_12x8, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
Pablo Telloeb82fd22018-02-23 13:43:50 +000073#else
Michalis Spyroue7e96e02018-04-13 13:44:10 +010074 return UniqueGemmCommon<float, float>(new GemmInterleaved<sgemm_8x6, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
Pablo Telloeb82fd22018-02-23 13:43:50 +000075#endif
76}
77
78// Instantiate static class variables.
79#ifdef __aarch64__
80const int sgemm_12x8::out_width;
81const int sgemm_12x8::out_height;
82
83const int sgemm_native_16x4::out_width;
84const int sgemm_native_16x4::out_height;
85#else
86const int sgemm_8x6::out_width;
87const int sgemm_8x6::out_height;
88#endif
89
90} // namespace arm_gemm