/*
 * Copyright (c) 2017-2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#ifndef ARM_COMPUTE_CL_GEMMLOWP_MATRIXMULTIPLY_CORE_H
25#define ARM_COMPUTE_CL_GEMMLOWP_MATRIXMULTIPLY_CORE_H
26
Matthew Benthamf1aeab92023-05-30 13:35:34 +000027#include "arm_compute/core/GEMMInfo.h"
Georgios Pinitasf4e84fb2021-07-08 15:36:07 +010028#include "arm_compute/core/TensorInfo.h"
29#include "arm_compute/runtime/CL/CLTypes.h"
30
Georgios Pinitas7891a732021-08-20 21:39:25 +010031#include "src/gpu/cl/ClCompileContext.h"
32#include "src/gpu/cl/IClOperator.h"
Georgios Pinitasf4e84fb2021-07-08 15:36:07 +010033
34namespace arm_compute
35{
36namespace opencl
37{
namespace kernels
{
// Forward declarations of the OpenCL kernel classes referenced by the operator
// declared in this header. Only pointers to these types are needed here, so the
// full kernel definitions do not have to be included.
class ClCastKernel;
class ClGemmLowpMatrixMultiplyNativeKernel;
class ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel;
class ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel;
class ClGemmReshapeRhsMatrixKernel;
class ClGemmLowpMatrixAReductionKernel;
class ClGemmLowpMatrixBReductionKernel;
class ClGemmLowpOffsetContributionKernel;
class ClGemmLowpOffsetContributionOutputStageKernel;
} // namespace kernels
51
52/** Basic function to execute GEMMLowpMatrixMultiplyCore on OpenCL. */
53class ClGemmLowpMatrixMultiplyCore : public IClOperator
54{
55public:
56 ClGemmLowpMatrixMultiplyCore();
57 ~ClGemmLowpMatrixMultiplyCore();
58 /** Initialise the kernel's inputs, output
59 *
60 * Valid data layouts:
61 * - NHWC
62 * - NCHW
63 *
64 * Valid data type configurations:
65 * |src0 |src1 |src2 |dst |
66 * |:--------------|:------------------|:--------|:--------------|
67 * |QASYMM8 |QASYMM8 |S32 |QASYMM8 |
68 * |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |QASYMM8 |
69 * |QASYMM8 |QSYMM8 |S32 |QASYMM8 |
70 * |QASYMM8 |QASYMM8 |S32 |S32 |
71 * |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |S32 |
72 * |QASYMM8 |QSYMM8 |S32 |S32 |
73 * |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED |
74 * |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |QASYMM8_SIGNED |
75 * |QASYMM8_SIGNED |QSYMM8 |S32 |QASYMM8_SIGNED |
76 * |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |S32 |
77 * |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |S32 |
78 * |QASYMM8_SIGNED |QSYMM8 |S32 |S32 |
79 *
80 * @note GEMMLowp: low precision GEMM kernel. [A * B + C]
81 * This kernel performs the following computations:
82 *
83 * -# Convert a values from 8-bit quantized to int32 and add a_offset to each of them.
84 * -# Convert b values from 8-bit quantized to int32 and add b_offset to each of them.
85 * -# Compute the matrix product of the resulting a * b in int32.
86 * -# Quantize to uint8 if gemm_info.gemmlowp_output_stage != NONE
87 *
88 * @param[in] compile_context The compile context to be used.
89 * @param[in] a First input tensor (Matrix A). Data type supported: QASYMM8/QASYMM8_SIGNED.
90 * @param[in] b Second input tensor (Matrix B). Data type supported: same as @p a
91 * @param[in] c Third input tensor (Matrix C). It can be a nullptr. Data type supported: S32
92 * @param[out] output Output tensor. Data type supported: S32 or QASYMM8/QASYMM8_SIGNED if gemm_info.gemmlowp_output_stage != NONE
93 * @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
94 * if the reshape of matrix B should be executed only for the first run
95 */
96 void configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, const GEMMInfo &gemm_info = GEMMInfo());
97 /** Static function to check if given info will lead to a valid configuration
98 *
99 * Similar to ClGemmLowpMatrixMultiplyCore::configure()
100 *
101 * @return a status
102 */
103 static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, const GEMMInfo &gemm_info = GEMMInfo());
104
105 // Inherited methods overridden:
106 void run(ITensorPack &tensors) override;
107 void prepare(ITensorPack &constants) override;
108 experimental::MemoryRequirements workspace() const override;
109
110private:
111 enum AuxTensorIdx
112 {
Georgios Pinitas529b5a22021-07-27 15:55:30 +0100113 ResultS32 = 0,
Georgios Pinitasf4e84fb2021-07-08 15:36:07 +0100114 RhsQAsymm8,
115 RhsReshape,
Georgios Pinitas529b5a22021-07-27 15:55:30 +0100116 VecSumCol,
117 VecSumRow,
Georgios Pinitasf4e84fb2021-07-08 15:36:07 +0100118 Multipliers,
119 Shifts,
120 Count
121 };
122
123private:
124 // Kernels used
Freddie Liardete572dff2022-05-16 14:09:10 +0100125 std::unique_ptr<kernels::ClCastKernel> _weights_to_qasymm8;
126 std::unique_ptr<kernels::ClGemmLowpMatrixMultiplyNativeKernel> _mm_native_kernel;
127 std::unique_ptr<kernels::ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel> _mm_reshaped_only_rhs_kernel;
128 std::unique_ptr<kernels::ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel> _mm_reshaped_only_rhs_mmul_kernel;
129 std::unique_ptr<kernels::ClGemmReshapeRhsMatrixKernel> _mtx_b_reshape_kernel;
130 std::unique_ptr<kernels::ClGemmLowpMatrixAReductionKernel> _mtx_a_reduction_kernel;
131 std::unique_ptr<kernels::ClGemmLowpMatrixBReductionKernel> _mtx_b_reduction_kernel;
132 std::unique_ptr<kernels::ClGemmLowpOffsetContributionKernel> _offset_contribution_kernel;
133 std::unique_ptr<kernels::ClGemmLowpOffsetContributionOutputStageKernel> _offset_contribution_output_stage_kernel;
Georgios Pinitasf4e84fb2021-07-08 15:36:07 +0100134
135 // Temporary tensors
136 TensorInfo _qasymm8_weights{};
137 TensorInfo _vector_sum_col{};
138 TensorInfo _vector_sum_row{};
139 TensorInfo _tmp_b{};
140 TensorInfo _mm_result_s32{};
141 TensorInfo _gemm_output_stage_multipliers{};
142 TensorInfo _gemm_output_stage_shifts{};
143
Freddie Liardete572dff2022-05-16 14:09:10 +0100144 int32_t _a_offset{ 0 };
145 int32_t _b_offset{ 0 };
146 bool _reshape_b_only_on_first_run{ false };
147 bool _run_output_stage{ false };
148 bool _convert_to_qasymm8{ false };
149 bool _run_offset_contribution{ false };
150 bool _is_prepared{ false };
151 GEMMInfo _gemm_info{};
152 CLGEMMKernelType _gemm_kernel_type{};
Georgios Pinitasf4e84fb2021-07-08 15:36:07 +0100153
154 experimental::MemoryRequirements _aux_mem{};
155};
156} // namespace opencl
157} // namespace arm_compute
Matthew Benthamf1aeab92023-05-30 13:35:34 +0000158#endif /* ARM_COMPUTE_CL_GEMMLOWP_MATRIXMULTIPLY_CORE_H */