Moritz Pflanzer | 4dfc235 | 2017-08-02 14:51:36 +0100 | [diff] [blame] | 1 | /* |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 2 | * Copyright (c) 2017-2021,2024 Arm Limited. |
Moritz Pflanzer | 4dfc235 | 2017-08-02 14:51:36 +0100 | [diff] [blame] | 3 | * |
| 4 | * SPDX-License-Identifier: MIT |
| 5 | * |
| 6 | * Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | * of this software and associated documentation files (the "Software"), to |
| 8 | * deal in the Software without restriction, including without limitation the |
| 9 | * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or |
| 10 | * sell copies of the Software, and to permit persons to whom the Software is |
| 11 | * furnished to do so, subject to the following conditions: |
| 12 | * |
| 13 | * The above copyright notice and this permission notice shall be included in all |
| 14 | * copies or substantial portions of the Software. |
| 15 | * |
| 16 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 21 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 22 | * SOFTWARE. |
| 23 | */ |
| 24 | #include "GEMM.h" |
| 25 | |
Michalis Spyrou | d1d7722 | 2020-04-08 14:10:15 +0100 | [diff] [blame] | 26 | #include "arm_compute/core/Helpers.h" |
Georgios Pinitas | 583137c | 2017-08-31 18:12:42 +0100 | [diff] [blame] | 27 | #include "arm_compute/core/Types.h" |
Moritz Pflanzer | 4dfc235 | 2017-08-02 14:51:36 +0100 | [diff] [blame] | 28 | |
| 29 | namespace arm_compute |
| 30 | { |
| 31 | namespace test |
| 32 | { |
| 33 | namespace validation |
| 34 | { |
| 35 | namespace reference |
| 36 | { |
| 37 | template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type> |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 38 | SimpleTensor<T> |
| 39 | gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta) |
Moritz Pflanzer | 4dfc235 | 2017-08-02 14:51:36 +0100 | [diff] [blame] | 40 | { |
| 41 | // Create reference |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 42 | SimpleTensor<T> dst{c.shape(), c.data_type(), 1}; |
Moritz Pflanzer | 4dfc235 | 2017-08-02 14:51:36 +0100 | [diff] [blame] | 43 | |
| 44 | // Compute reference |
Gian Marco Iodice | 2213d4b | 2018-04-27 10:39:06 +0100 | [diff] [blame] | 45 | const int M = a.shape().y(); |
| 46 | const int N = b.shape().x(); |
Moritz Pflanzer | 4dfc235 | 2017-08-02 14:51:36 +0100 | [diff] [blame] | 47 | const int K = a.shape().x(); |
Gian Marco Iodice | 2213d4b | 2018-04-27 10:39:06 +0100 | [diff] [blame] | 48 | const int D = a.shape().z(); // Number of matrices in a batch |
| 49 | const int W = a.shape()[3]; // Number of batched-gemm (Winograd case) |
Moritz Pflanzer | 4dfc235 | 2017-08-02 14:51:36 +0100 | [diff] [blame] | 50 | |
Gian Marco Iodice | 2213d4b | 2018-04-27 10:39:06 +0100 | [diff] [blame] | 51 | const int a_stride_z = K * M; |
| 52 | const int a_stride_w = K * M * D; |
| 53 | |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 54 | const int b_stride_z = |
| 55 | b.shape().num_dimensions() > 2 |
| 56 | ? N * K |
| 57 | : 0; // Do not slide the matrix B along the 3th dimension in case matrix B has less than 3 dimensions |
| 58 | int b_stride_w = |
| 59 | b.shape().num_dimensions() > 3 |
| 60 | ? K * N * D |
| 61 | : 0; // Do not slide the matrix B along the 4th dimension in case matrix B has less than 4 dimensions |
Gian Marco Iodice | 37a4611 | 2021-08-04 15:22:28 +0100 | [diff] [blame] | 62 | |
| 63 | // Note: There are 3 gemm types: batched-gemm, multi-gemm, and batched of multi-gemms. The third dimension of tensor b is overloaded when tensor b has exactly 3 dimensions: |
| 64 | // it can be either number of batches or multis. Batched-GEMM computation is detected only when the third dimension of "a" and "c" tensors is 1 and the number of dimensions is 4 |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 65 | const bool is_batched_gemm = b.shape().num_dimensions() == 3 && a.shape().num_dimensions() == 4 && |
| 66 | c.shape().num_dimensions() == 4 && a.shape()[2] == 1 && c.shape()[2] == 1; |
Gian Marco Iodice | 37a4611 | 2021-08-04 15:22:28 +0100 | [diff] [blame] | 67 | |
| 68 | // Batched-GEMM |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 69 | if (is_batched_gemm) |
Gian Marco Iodice | 37a4611 | 2021-08-04 15:22:28 +0100 | [diff] [blame] | 70 | { |
| 71 | b_stride_w = b_stride_z; |
| 72 | } |
Gian Marco Iodice | 2213d4b | 2018-04-27 10:39:06 +0100 | [diff] [blame] | 73 | |
| 74 | const int c_stride_z = N * M; |
| 75 | const int c_stride_w = N * M * D; |
| 76 | |
Gian Marco Iodice | 37a4611 | 2021-08-04 15:22:28 +0100 | [diff] [blame] | 77 | #if defined(_OPENMP) && !(defined(__arm__) && defined(__ANDROID__)) |
Michalis Spyrou | d1d7722 | 2020-04-08 14:10:15 +0100 | [diff] [blame] | 78 | #pragma omp parallel for collapse(2) |
| 79 | #endif /* _OPENMP */ |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 80 | for (int w = 0; w < W; ++w) |
Moritz Pflanzer | 4dfc235 | 2017-08-02 14:51:36 +0100 | [diff] [blame] | 81 | { |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 82 | for (int depth = 0; depth < D; ++depth) |
Moritz Pflanzer | 4dfc235 | 2017-08-02 14:51:36 +0100 | [diff] [blame] | 83 | { |
Gian Marco Iodice | 2213d4b | 2018-04-27 10:39:06 +0100 | [diff] [blame] | 84 | const int base_addr_a = depth * a_stride_z + w * a_stride_w; |
| 85 | const int base_addr_b = depth * b_stride_z + w * b_stride_w; |
| 86 | const int base_addr_c = depth * c_stride_z + w * c_stride_w; |
Moritz Pflanzer | 4dfc235 | 2017-08-02 14:51:36 +0100 | [diff] [blame] | 87 | |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 88 | for (int row = 0; row < M; ++row) |
Moritz Pflanzer | 4dfc235 | 2017-08-02 14:51:36 +0100 | [diff] [blame] | 89 | { |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 90 | for (int col = 0; col < N; ++col) |
Gian Marco Iodice | 2213d4b | 2018-04-27 10:39:06 +0100 | [diff] [blame] | 91 | { |
| 92 | T acc(0); |
Moritz Pflanzer | 4dfc235 | 2017-08-02 14:51:36 +0100 | [diff] [blame] | 93 | |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 94 | for (int k = 0; k < K; ++k) |
Gian Marco Iodice | 2213d4b | 2018-04-27 10:39:06 +0100 | [diff] [blame] | 95 | { |
| 96 | acc += a[base_addr_a + k + row * K] * b[base_addr_b + col + k * N]; |
| 97 | } |
| 98 | |
| 99 | // Finalize the result: alpha * A * B + beta * C |
| 100 | dst[base_addr_c + col + row * N] = alpha * acc + beta * c[base_addr_c + col + row * N]; |
| 101 | } |
| 102 | } |
Moritz Pflanzer | 4dfc235 | 2017-08-02 14:51:36 +0100 | [diff] [blame] | 103 | } |
| 104 | } |
| 105 | |
| 106 | return dst; |
| 107 | } |
| 108 | |
Gian Marco Iodice | 0c17aa2 | 2019-09-27 09:23:15 +0100 | [diff] [blame] | 109 | template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type> |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 110 | SimpleTensor<T> gemm_mixed_precision( |
| 111 | const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta) |
Gian Marco Iodice | 0c17aa2 | 2019-09-27 09:23:15 +0100 | [diff] [blame] | 112 | { |
| 113 | // GEMM mixed-precision combines F32 accumulators with F16 multiplications |
| 114 | // Create reference |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 115 | SimpleTensor<T> dst{c.shape(), c.data_type(), 1}; |
Gian Marco Iodice | 0c17aa2 | 2019-09-27 09:23:15 +0100 | [diff] [blame] | 116 | |
| 117 | // Compute reference |
| 118 | const int M = a.shape().y(); |
| 119 | const int N = b.shape().x(); |
| 120 | const int K = a.shape().x(); |
| 121 | const int D = a.shape().z(); // Number of matrices in a batch |
| 122 | const int W = a.shape()[3]; // Number of batched-gemm (Winograd case) |
| 123 | |
| 124 | const int a_stride_z = K * M; |
| 125 | const int a_stride_w = K * M * D; |
| 126 | |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 127 | const int b_stride_z = |
| 128 | b.shape().num_dimensions() > 2 |
| 129 | ? N * K |
| 130 | : 0; // Do not slide the matrix B along the 3th dimension in case matrix B has less than 3 dimensions |
| 131 | int b_stride_w = |
| 132 | b.shape().num_dimensions() > 3 |
| 133 | ? K * N * D |
| 134 | : 0; // Do not slide the matrix B along the 4th dimension in case matrix B has less than 4 dimensions |
Gian Marco Iodice | 37a4611 | 2021-08-04 15:22:28 +0100 | [diff] [blame] | 135 | |
| 136 | // Note: There are 3 gemm types: batched-gemm, multi-gemm, and batched of multi-gemms. The third dimension of tensor b is overloaded when tensor b has exactly 3 dimensions: |
| 137 | // it can be either number of batches or multis. Batched-GEMM computation is detected only when the third dimension of "a" and "c" tensors is 1 and the number of dimensions is 4 |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 138 | const bool is_batched_gemm = b.shape().num_dimensions() == 3 && a.shape().num_dimensions() == 4 && |
| 139 | c.shape().num_dimensions() == 4 && a.shape()[2] == 1 && c.shape()[2] == 1; |
Gian Marco Iodice | 37a4611 | 2021-08-04 15:22:28 +0100 | [diff] [blame] | 140 | |
| 141 | // Batched-GEMM |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 142 | if (is_batched_gemm) |
Gian Marco Iodice | 37a4611 | 2021-08-04 15:22:28 +0100 | [diff] [blame] | 143 | { |
| 144 | b_stride_w = b_stride_z; |
| 145 | } |
Gian Marco Iodice | 0c17aa2 | 2019-09-27 09:23:15 +0100 | [diff] [blame] | 146 | |
| 147 | const int c_stride_z = N * M; |
| 148 | const int c_stride_w = N * M * D; |
| 149 | |
Gian Marco Iodice | 37a4611 | 2021-08-04 15:22:28 +0100 | [diff] [blame] | 150 | #if defined(_OPENMP) && !(defined(__arm__) && defined(__ANDROID__)) |
Michalis Spyrou | d1d7722 | 2020-04-08 14:10:15 +0100 | [diff] [blame] | 151 | #pragma omp parallel for collapse(2) |
| 152 | #endif /* _OPENMP */ |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 153 | for (int w = 0; w < W; ++w) |
Gian Marco Iodice | 0c17aa2 | 2019-09-27 09:23:15 +0100 | [diff] [blame] | 154 | { |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 155 | for (int depth = 0; depth < D; ++depth) |
Gian Marco Iodice | 0c17aa2 | 2019-09-27 09:23:15 +0100 | [diff] [blame] | 156 | { |
| 157 | const int base_addr_a = depth * a_stride_z + w * a_stride_w; |
| 158 | const int base_addr_b = depth * b_stride_z + w * b_stride_w; |
| 159 | const int base_addr_c = depth * c_stride_z + w * c_stride_w; |
| 160 | |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 161 | for (int row = 0; row < M; ++row) |
Gian Marco Iodice | 0c17aa2 | 2019-09-27 09:23:15 +0100 | [diff] [blame] | 162 | { |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 163 | for (int col = 0; col < N; ++col) |
Gian Marco Iodice | 0c17aa2 | 2019-09-27 09:23:15 +0100 | [diff] [blame] | 164 | { |
| 165 | float acc(0); |
| 166 | |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 167 | for (int k = 0; k < K; ++k) |
Gian Marco Iodice | 0c17aa2 | 2019-09-27 09:23:15 +0100 | [diff] [blame] | 168 | { |
| 169 | acc += static_cast<float>(a[base_addr_a + k + row * K] * b[base_addr_b + col + k * N]); |
| 170 | } |
| 171 | |
| 172 | // Finalize the result: alpha * A * B + beta * C |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 173 | dst[base_addr_c + col + row * N] = |
| 174 | static_cast<T>(alpha * acc + beta * c[base_addr_c + col + row * N]); |
Gian Marco Iodice | 0c17aa2 | 2019-09-27 09:23:15 +0100 | [diff] [blame] | 175 | } |
| 176 | } |
| 177 | } |
| 178 | } |
| 179 | |
| 180 | return dst; |
| 181 | } |
| 182 | |
Renato Arantes | 36a75da | 2024-01-26 17:31:18 +0000 | [diff] [blame^] | 183 | template SimpleTensor<float> |
| 184 | gemm(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta); |
| 185 | template SimpleTensor<bfloat16> gemm(const SimpleTensor<bfloat16> &a, |
| 186 | const SimpleTensor<bfloat16> &b, |
| 187 | const SimpleTensor<bfloat16> &c, |
| 188 | float alpha, |
| 189 | float beta); |
| 190 | template SimpleTensor<half> |
| 191 | gemm(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta); |
| 192 | template SimpleTensor<half> gemm_mixed_precision( |
| 193 | const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta); |
Moritz Pflanzer | 4dfc235 | 2017-08-02 14:51:36 +0100 | [diff] [blame] | 194 | } // namespace reference |
| 195 | } // namespace validation |
| 196 | } // namespace test |
| 197 | } // namespace arm_compute |