blob: 74161e330e6585925774295e0e49794d9e443f11 [file] [log] [blame]
Anthony Barbierc8e84b52018-07-17 16:48:42 +01001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2018-2019 Arm Limited.
Anthony Barbierc8e84b52018-07-17 16:48:42 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michalis Spyrouf4643372019-11-29 16:17:13 +000024#ifndef ARM_COMPUTE_INEGEMMWRAPPERKERNEL_H
25#define ARM_COMPUTE_INEGEMMWRAPPERKERNEL_H
Anthony Barbierc8e84b52018-07-17 16:48:42 +010026
27#include "arm_compute/core/NEON/INEKernel.h"
28
29namespace arm_compute
30{
31class ITensor;
32
33/** Common interface for all the arm_gemm Gemms
34 */
35class INEGEMMWrapperKernel : public INEKernel
36{
37public:
38 /** Parameters defining the dimensions of the matrices being multiplied */
39 struct Params
40 {
Anthony Barbier3d677cc2018-07-23 16:42:59 +010041 unsigned int M{ 0 }; /**< Rows in output matrix C (and input matrix A). */
42 unsigned int N{ 0 }; /**< Columns in output matrix C (and input matrix B). */
43 unsigned int K{ 0 }; /**< Columns of input matrix A (= rows of input matrix B). */
44 unsigned int batches{ 0 }; /**< Number of "batched" GEMMs (unique A and C, shared B). */
45 unsigned int multis{ 0 }; /**< Number of "multi" GEMMs (unique A, B and C). */
Anthony Barbierc8e84b52018-07-17 16:48:42 +010046 };
47
Georgios Pinitas37d080f2019-06-21 18:43:12 +010048 static Params extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c, const GEMMInfo &gemm_info);
Anthony Barbierc8e84b52018-07-17 16:48:42 +010049
50 /** Constructor */
51 INEGEMMWrapperKernel();
52 /** Prevent instances of this class from being copied */
53 INEGEMMWrapperKernel(const INEGEMMWrapperKernel &) = delete;
54 /** Prevent instances of this class from being copied */
55 INEGEMMWrapperKernel &operator=(const INEGEMMWrapperKernel &) = delete;
56 /** Allow instances of this class to be moved */
57 INEGEMMWrapperKernel(INEGEMMWrapperKernel &&) = default;
58 /** Allow instances of this class to be moved */
59 INEGEMMWrapperKernel &operator=(INEGEMMWrapperKernel &&) = default;
60 /** Initialise the kernel's input and output.
61 *
62 * @note The input and output tensor must have the same dimensions
63 *
Georgios Pinitas37d080f2019-06-21 18:43:12 +010064 * @param[in] a Input tensor (Matrix A)
65 * @param[in] b Input tensor (Matrix B)
66 * @param[out] c Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
67 * @param[in] alpha Scalar multiplier to apply to AB matrix product.
68 * @param[in] beta Scalar multiplier to apply to input C matrix before adding product.
69 * @param[in] gemm_info GEMM meta-data
Anthony Barbierc8e84b52018-07-17 16:48:42 +010070 */
Georgios Pinitas37d080f2019-06-21 18:43:12 +010071 void configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info);
Anthony Barbierc8e84b52018-07-17 16:48:42 +010072
73 // Inherited methods overridden:
74 void run(const Window &window, const ThreadInfo &info) override;
75
76protected:
77 /** Called as part of configure() after _a, _b, _c and _params have been set.
78 *
79 * @param[in] alpha Scalar multiplier to apply to AB matrix product.
80 * @param[in] beta Scalar multiplier to apply to input C matrix before adding product.
81 *
82 * @return A 3D execution window.
83 */
84 virtual Window configure_internal(float alpha, float beta) = 0;
85
86 /** Run the kernel from the start to the end offset in window.
87 *
88 * @param[in] window Window to use for the iteration
89 * @param[in] start_offset Where to start iterating from (In Window coordinates)
90 * @param[in] end_offset Where to stop iterating (In Window coordinates).
91 * @param[in] info Info about executing thread and CPU.
92 */
93 virtual void run_internal(const Window &window, const Coordinates &start_offset, const Coordinates &end_offset, const ThreadInfo &info) = 0;
94
95 const ITensor *_a;
96 const ITensor *_b;
97 ITensor *_c;
98 Params _params;
Georgios Pinitas37d080f2019-06-21 18:43:12 +010099 GEMMInfo _gemm_info;
Anthony Barbierc8e84b52018-07-17 16:48:42 +0100100
101private:
102 Window _window3d;
103 TensorShape _window_shape;
104};
105
106} // namespace arm_compute
107
Michalis Spyrouf4643372019-11-29 16:17:13 +0000108#endif /* ARM_COMPUTE_INEGEMMWRAPPERKERNEL_H */