blob: f9dcfcbdd0de94d574b572cd2aa30f7d460fd32d [file] [log] [blame]
/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "GEMM.h"
25
Georgios Pinitas583137c2017-08-31 18:12:42 +010026#include "arm_compute/core/Types.h"
Moritz Pflanzera09de0c2017-09-01 20:41:12 +010027#include "tests/validation/FixedPoint.h"
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +010028
29namespace arm_compute
30{
31namespace test
32{
33namespace validation
34{
35namespace reference
36{
37template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
38SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta)
39{
40 // Create reference
41 SimpleTensor<T> dst{ c.shape(), c.data_type(), 1, c.fixed_point_position() };
42
43 // Compute reference
Gian Marco Iodice2213d4b2018-04-27 10:39:06 +010044 const int M = a.shape().y();
45 const int N = b.shape().x();
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +010046 const int K = a.shape().x();
Gian Marco Iodice2213d4b2018-04-27 10:39:06 +010047 const int D = a.shape().z(); // Number of matrices in a batch
48 const int W = a.shape()[3]; // Number of batched-gemm (Winograd case)
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +010049
Gian Marco Iodice2213d4b2018-04-27 10:39:06 +010050 const int a_stride_z = K * M;
51 const int a_stride_w = K * M * D;
52
53 const int b_stride_z = b.shape().num_dimensions() > 2 ? N * K : 0; // Do not slide the matrix B along the 3th dimension in case matrix B has less than 3 dimensions
54 const int b_stride_w = b.shape().num_dimensions() > 3 ? K * N * D : 0; // Do not slide the matrix B along the 4th dimension in case matrix B has less than 4 dimensions
55
56 const int c_stride_z = N * M;
57 const int c_stride_w = N * M * D;
58
59 for(int w = 0; w < W; ++w)
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +010060 {
Gian Marco Iodice2213d4b2018-04-27 10:39:06 +010061 for(int depth = 0; depth < D; ++depth)
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +010062 {
Gian Marco Iodice2213d4b2018-04-27 10:39:06 +010063 const int base_addr_a = depth * a_stride_z + w * a_stride_w;
64 const int base_addr_b = depth * b_stride_z + w * b_stride_w;
65 const int base_addr_c = depth * c_stride_z + w * c_stride_w;
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +010066
Gian Marco Iodice2213d4b2018-04-27 10:39:06 +010067 for(int row = 0; row < M; ++row)
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +010068 {
Gian Marco Iodice2213d4b2018-04-27 10:39:06 +010069 for(int col = 0; col < N; ++col)
70 {
71 T acc(0);
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +010072
Gian Marco Iodice2213d4b2018-04-27 10:39:06 +010073 for(int k = 0; k < K; ++k)
74 {
75 acc += a[base_addr_a + k + row * K] * b[base_addr_b + col + k * N];
76 }
77
78 // Finalize the result: alpha * A * B + beta * C
79 dst[base_addr_c + col + row * N] = alpha * acc + beta * c[base_addr_c + col + row * N];
80 }
81 }
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +010082 }
83 }
84
85 return dst;
86}
87
88template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type>
89SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta)
90{
91 using namespace fixed_point_arithmetic;
92
93 // Create reference
94 SimpleTensor<T> dst{ c.shape(), c.data_type(), 1, c.fixed_point_position() };
95
96 // Compute reference
97 using promoted_type = fixed_point_arithmetic::traits::promote_t<T>;
98
Gian Marco Iodice2213d4b2018-04-27 10:39:06 +010099 const int M = dst.shape().y();
100 const int N = dst.shape().x();
101 const int K = a.shape().x();
102 const int D = a.shape().z(); // Number of matrices in a batch
103 const int W = a.shape()[3]; // Number of batched-gemm (Winograd case)
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100104
Gian Marco Iodice2213d4b2018-04-27 10:39:06 +0100105 const int a_stride_z = K * M;
106 const int a_stride_w = K * M * D;
107
108 const int b_stride_z = b.shape().num_dimensions() > 2 ? N * K : 0; // Do not slide the matrix B along the 3th dimension in case matrix B has less than 3 dimensions
109 const int b_stride_w = b.shape().num_dimensions() > 3 ? K * N * D : 0; // Do not slide the matrix B along the 4th dimension in case matrix B has less than 4 dimensions
110
111 const int c_stride_z = N * M;
112 const int c_stride_w = N * M * D;
113
114 const int fixed_point_position = a.fixed_point_position();
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100115 const fixed_point<T> alpha_q(alpha, fixed_point_position);
116 const fixed_point<T> beta_q(beta, fixed_point_position);
117
Gian Marco Iodice2213d4b2018-04-27 10:39:06 +0100118 for(int w = 0; w < W; ++w)
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100119 {
Gian Marco Iodice2213d4b2018-04-27 10:39:06 +0100120 for(int depth = 0; depth < D; ++depth)
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100121 {
Gian Marco Iodice2213d4b2018-04-27 10:39:06 +0100122 const int base_addr_a = depth * a_stride_z + w * a_stride_w;
123 const int base_addr_b = depth * b_stride_z + w * b_stride_w;
124 const int base_addr_c = depth * c_stride_z + w * c_stride_w;
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100125
Gian Marco Iodice2213d4b2018-04-27 10:39:06 +0100126 for(int row = 0; row < M; ++row)
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100127 {
Gian Marco Iodice2213d4b2018-04-27 10:39:06 +0100128 for(int col = 0; col < N; ++col)
129 {
130 fixed_point<promoted_type> acc_q(0, fixed_point_position);
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100131
Gian Marco Iodice2213d4b2018-04-27 10:39:06 +0100132 for(int k = 0; k < K; ++k)
133 {
134 const fixed_point<promoted_type> a0_q(a[base_addr_a + row * K + k], fixed_point_position, true);
135 const fixed_point<promoted_type> b0_q(b[base_addr_b + k * N + col], fixed_point_position, true);
136
137 acc_q = acc_q + (a0_q * b0_q);
138 }
139
140 // Finalize the result: alpha * A * B + beta * C
141 const fixed_point<T> c0_q(c[base_addr_c + col + row * N], fixed_point_position, true);
142
143 fixed_point<T> res_q(acc_q);
144 res_q = alpha_q * res_q;
145 res_q = res_q + (beta_q * c0_q);
146
147 // Store the result
148 dst[base_addr_c + col + row * N] = res_q.raw();
149 }
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100150 }
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100151 }
152 }
153
154 return dst;
155}
156
157template SimpleTensor<float> gemm(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta);
Georgios Pinitas583137c2017-08-31 18:12:42 +0100158template SimpleTensor<half> gemm(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100159template SimpleTensor<qint8_t> gemm(const SimpleTensor<qint8_t> &a, const SimpleTensor<qint8_t> &b, const SimpleTensor<qint8_t> &c, float alpha, float beta);
160template SimpleTensor<qint16_t> gemm(const SimpleTensor<qint16_t> &a, const SimpleTensor<qint16_t> &b, const SimpleTensor<qint16_t> &c, float alpha, float beta);
161} // namespace reference
162} // namespace validation
163} // namespace test
164} // namespace arm_compute