/*
 * Copyright (c) 2017-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#ifndef ARM_COMPUTE_CL_GEMMLOWP_MATRIXMULTIPLY_CORE_H
25#define ARM_COMPUTE_CL_GEMMLOWP_MATRIXMULTIPLY_CORE_H
26
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/CL/CLTypes.h"

#include "src/gpu/cl/ClCompileContext.h"
#include "src/gpu/cl/IClOperator.h"

#include <cstdint>
#include <memory>
Georgios Pinitasf4e84fb2021-07-08 15:36:07 +010032
33namespace arm_compute
34{
35namespace opencl
36{
namespace kernels
{
// Forward declarations of the kernels owned by ClGemmLowpMatrixMultiplyCore.
// The operator only stores them through std::unique_ptr, so this header does
// not need the full kernel definitions.
class ClCastKernel;
class ClGemmLowpMatrixMultiplyNativeKernel;
class ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel;
class ClGemmReshapeRhsMatrixKernel;
class ClGemmLowpMatrixAReductionKernel;
class ClGemmLowpMatrixBReductionKernel;
class ClGemmLowpOffsetContributionKernel;
class ClGemmLowpOffsetContributionOutputStageKernel;
} // namespace kernels
49
50/** Basic function to execute GEMMLowpMatrixMultiplyCore on OpenCL. */
51class ClGemmLowpMatrixMultiplyCore : public IClOperator
52{
53public:
54 ClGemmLowpMatrixMultiplyCore();
55 ~ClGemmLowpMatrixMultiplyCore();
56 /** Initialise the kernel's inputs, output
57 *
58 * Valid data layouts:
59 * - NHWC
60 * - NCHW
61 *
62 * Valid data type configurations:
63 * |src0 |src1 |src2 |dst |
64 * |:--------------|:------------------|:--------|:--------------|
65 * |QASYMM8 |QASYMM8 |S32 |QASYMM8 |
66 * |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |QASYMM8 |
67 * |QASYMM8 |QSYMM8 |S32 |QASYMM8 |
68 * |QASYMM8 |QASYMM8 |S32 |S32 |
69 * |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |S32 |
70 * |QASYMM8 |QSYMM8 |S32 |S32 |
71 * |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED |
72 * |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |QASYMM8_SIGNED |
73 * |QASYMM8_SIGNED |QSYMM8 |S32 |QASYMM8_SIGNED |
74 * |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |S32 |
75 * |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |S32 |
76 * |QASYMM8_SIGNED |QSYMM8 |S32 |S32 |
77 *
78 * @note GEMMLowp: low precision GEMM kernel. [A * B + C]
79 * This kernel performs the following computations:
80 *
81 * -# Convert a values from 8-bit quantized to int32 and add a_offset to each of them.
82 * -# Convert b values from 8-bit quantized to int32 and add b_offset to each of them.
83 * -# Compute the matrix product of the resulting a * b in int32.
84 * -# Quantize to uint8 if gemm_info.gemmlowp_output_stage != NONE
85 *
86 * @param[in] compile_context The compile context to be used.
87 * @param[in] a First input tensor (Matrix A). Data type supported: QASYMM8/QASYMM8_SIGNED.
88 * @param[in] b Second input tensor (Matrix B). Data type supported: same as @p a
89 * @param[in] c Third input tensor (Matrix C). It can be a nullptr. Data type supported: S32
90 * @param[out] output Output tensor. Data type supported: S32 or QASYMM8/QASYMM8_SIGNED if gemm_info.gemmlowp_output_stage != NONE
91 * @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
92 * if the reshape of matrix B should be executed only for the first run
93 */
94 void configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, const GEMMInfo &gemm_info = GEMMInfo());
95 /** Static function to check if given info will lead to a valid configuration
96 *
97 * Similar to ClGemmLowpMatrixMultiplyCore::configure()
98 *
99 * @return a status
100 */
101 static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, const GEMMInfo &gemm_info = GEMMInfo());
102
103 // Inherited methods overridden:
104 void run(ITensorPack &tensors) override;
105 void prepare(ITensorPack &constants) override;
106 experimental::MemoryRequirements workspace() const override;
107
108private:
109 enum AuxTensorIdx
110 {
Georgios Pinitas529b5a22021-07-27 15:55:30 +0100111 ResultS32 = 0,
Georgios Pinitasf4e84fb2021-07-08 15:36:07 +0100112 RhsQAsymm8,
113 RhsReshape,
Georgios Pinitas529b5a22021-07-27 15:55:30 +0100114 VecSumCol,
115 VecSumRow,
Georgios Pinitasf4e84fb2021-07-08 15:36:07 +0100116 Multipliers,
117 Shifts,
118 Count
119 };
120
121private:
122 // Kernels used
123 std::unique_ptr<kernels::ClCastKernel> _weights_to_qasymm8;
124 std::unique_ptr<kernels::ClGemmLowpMatrixMultiplyNativeKernel> _mm_native_kernel;
125 std::unique_ptr<kernels::ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel> _mm_reshaped_only_rhs_kernel;
126 std::unique_ptr<kernels::ClGemmReshapeRhsMatrixKernel> _mtx_b_reshape_kernel;
127 std::unique_ptr<kernels::ClGemmLowpMatrixAReductionKernel> _mtx_a_reduction_kernel;
128 std::unique_ptr<kernels::ClGemmLowpMatrixBReductionKernel> _mtx_b_reduction_kernel;
129 std::unique_ptr<kernels::ClGemmLowpOffsetContributionKernel> _offset_contribution_kernel;
130 std::unique_ptr<kernels::ClGemmLowpOffsetContributionOutputStageKernel> _offset_contribution_output_stage_kernel;
131
132 // Temporary tensors
133 TensorInfo _qasymm8_weights{};
134 TensorInfo _vector_sum_col{};
135 TensorInfo _vector_sum_row{};
136 TensorInfo _tmp_b{};
137 TensorInfo _mm_result_s32{};
138 TensorInfo _gemm_output_stage_multipliers{};
139 TensorInfo _gemm_output_stage_shifts{};
140
141 int32_t _a_offset{ 0 };
142 int32_t _b_offset{ 0 };
143 bool _is_gemm_reshaped{ true };
144 bool _reshape_b_only_on_first_run{ false };
145 bool _run_output_stage{ false };
146 bool _convert_to_qasymm8{ false };
147 bool _run_offset_contribution{ false };
148 bool _is_prepared{ false };
149 GEMMInfo _gemm_info{};
150
151 experimental::MemoryRequirements _aux_mem{};
152};
153} // namespace opencl
154} // namespace arm_compute
155#endif /* ARM_COMPUTE_CL_GEMMLOWP_MATRIXMULTIPLY_CORE_H */