blob: a0d13c3e39a0faa3b9ba2f1b4cff13b814cf2839 [file] [log] [blame]
Freddie Liardete572dff2022-05-16 14:09:10 +01001/*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/CL/functions/CLCast.h"
25#include "arm_compute/runtime/CL/functions/CLReductionOperation.h"
26#include "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.h"
27#include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h"
28#include "tests/CL/CLAccessor.h"
29#include "tests/CL/Helper.h"
30#include "tests/framework/Macros.h"
31#include "tests/framework/datasets/Datasets.h"
32#include "tests/validation/fixtures/GEMMLowpFixture.h"
33
34namespace arm_compute
35{
36namespace test
37{
38namespace validation
39{
40using namespace arm_compute::opencl::kernels;
41
42// Create function for CLGEMMReshapeRHSMatrixKernel
43using CLGEMMReshapeRHSMatrix = CLSynthetizeOperator<opencl::kernels::ClGemmReshapeRhsMatrixKernel>;
44
45// Create function for CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel
46using CLGEMMLowpMatrixMultiplyReshapedOnlyRHS = CLSynthetizeOperator<opencl::kernels::ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel>;
47
48// Fixture for CLGEMMLowpMatrixMultiplyReshapedOnlyRHS
49using CLGEMMLowpMatrixMultiplyReshapedOnlyRHSMMULFixture =
50 GEMMLowpMatrixMultiplyReshapedOnlyRHSMMULValidationFixture<CLTensor, CLAccessor, CLGEMMReshapeRHSMatrix, CLGEMMLowpMatrixMultiplyReshapedOnlyRHS>;
51
52// Fixture for CLGEMMLowpMatrixMultiplyReshapedOnlyRHS
53using CLGEMMLowpMatrixMultiplyReshapedOnlyRHSMMULOutputStageFixtureSigned =
54 GEMMLowpMatrixMultiplyReshapedOnlyRHSMMULOutputStageValidationFixture<int8_t, CLTensor, CLAccessor, CLGEMMReshapeRHSMatrix, CLGEMMLowpMatrixMultiplyReshapedOnlyRHS, CLReductionOperation, CLCast>;
55
56using CLGEMMLowpMatrixMultiplyReshapedOnlyRHSMMULOutputStageFixtureUnsigned =
57 GEMMLowpMatrixMultiplyReshapedOnlyRHSMMULOutputStageValidationFixture<uint8_t, CLTensor, CLAccessor, CLGEMMReshapeRHSMatrix, CLGEMMLowpMatrixMultiplyReshapedOnlyRHS, CLReductionOperation, CLCast>;
58
59namespace
60{
61// *INDENT-OFF*
62// clang-format off
63
64/** M values to test */
65const auto m_values = framework::dataset::make("M", {16, 49});
66
67/** N values to test */
68const auto n_values = framework::dataset::make("N", {16, 259});
69
70/** K values to test */
71const auto k_values = framework::dataset::make("K", {192});
72
73/** Batch size values to test */
74const auto b_values = framework::dataset::make("batch_size", {1, 2});
75
76/** M0 values to test - Precommit */
77const auto m0 = framework::dataset::make("M0", {1, 2, 4});
78
79/** N0 values to test - Precommit */
80const auto n0 = framework::dataset::make("N0", { 1, 4, 8});
81
82/** K0 values to test - Precommit */
83const auto k0 = framework::dataset::make("K0", { 4 });
84
85/** H0 values to test - Precommit */
86const auto h0 = framework::dataset::make("H0", 1);
87
88/** Interleave values to test with RHS matrix */
89const auto i_values_rhs = framework::dataset::make("interleave_rhs", { false });
90
91/** Transpose values to test with RHS matrix */
92const auto t_values_rhs = framework::dataset::make("transpose_rhs", { true });
93
94const auto broadcast_bias = framework::dataset::make("broadcast_bias", {true, false});
95
96} // namespace
97
98TEST_SUITE(CL)
99TEST_SUITE(GEMMLowpMatrixMultiplyReshapedOnlyRhsMMUL)
100FIXTURE_DATA_TEST_CASE(Signed, CLGEMMLowpMatrixMultiplyReshapedOnlyRHSMMULFixture, framework::DatasetMode::ALL,
101 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
102 m_values,
103 n_values),
104 k_values),
105 b_values),
106 m0),
107 n0),
108 k0),
109 h0),
110 i_values_rhs),
111 t_values_rhs),
112 framework::dataset::make("DataType", { DataType::QASYMM8_SIGNED })))
113{
114 // Validate output
115 if(arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()))
116 {
117 validate(CLAccessor(_target), _reference);
118 }
119 else
120 {
121 ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
122 framework::ARM_COMPUTE_PRINT_INFO();
123 }
124}
125FIXTURE_DATA_TEST_CASE(Unsigned, CLGEMMLowpMatrixMultiplyReshapedOnlyRHSMMULFixture, framework::DatasetMode::ALL,
126 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
127 m_values,
128 n_values),
129 k_values),
130 b_values),
131 m0),
132 n0),
133 k0),
134 h0),
135 i_values_rhs),
136 t_values_rhs),
137 framework::dataset::make("DataType", { DataType::QASYMM8})))
138{
139 // Validate output
140 if(arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()))
141 {
142 validate(CLAccessor(_target), _reference);
143 }
144 else
145 {
146 ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
147 framework::ARM_COMPUTE_PRINT_INFO();
148 }
149}
150FIXTURE_DATA_TEST_CASE(OutputStageSigned, CLGEMMLowpMatrixMultiplyReshapedOnlyRHSMMULOutputStageFixtureSigned, framework::DatasetMode::ALL,
151 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
152 m_values,
153 n_values),
154 k_values),
155 b_values),
156 m0),
157 n0),
158 k0),
159 h0),
160 i_values_rhs),
161 t_values_rhs),
162 broadcast_bias),
163 framework::dataset::make("DataType", { DataType::QASYMM8_SIGNED})))
164{
165 // Validate output
166 if(arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()))
167 {
168 validate(CLAccessor(_target), _reference);
169 }
170 else
171 {
172 ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
173 framework::ARM_COMPUTE_PRINT_INFO();
174 }
175}
176FIXTURE_DATA_TEST_CASE(OutputStageUnsigned, CLGEMMLowpMatrixMultiplyReshapedOnlyRHSMMULOutputStageFixtureUnsigned, framework::DatasetMode::ALL,
177 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
178 m_values,
179 n_values),
180 k_values),
181 b_values),
182 m0),
183 n0),
184 k0),
185 h0),
186 i_values_rhs),
187 t_values_rhs),
188 broadcast_bias),
189 framework::dataset::make("DataType", { DataType::QASYMM8})))
190{
191 // Validate output
192 if(arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()))
193 {
194 validate(CLAccessor(_target), _reference);
195 }
196 else
197 {
198 ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
199 framework::ARM_COMPUTE_PRINT_INFO();
200 }
201}
202TEST_SUITE_END() // GEMMLowpMatrixMultiplyReshapedOnlyRhsMMUL
203TEST_SUITE_END() // CL
204} // namespace validation
205} // namespace test
206} // namespace arm_compute