COMPMID-1964: Implement CLGEMMMatrixMultiplyReshapedOnlyRHS - Not transposed
Change-Id: I6b7f7a406fcb8d64adb221a24d7983e11bcb391e
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/846
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
index cbbc592..83051d2 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
@@ -103,7 +103,7 @@
const auto h0_values_precommit = framework::dataset::make("H0", 1, 3);
/** M0 values to test - Nightly */
-const auto m0_values_nightly = framework::dataset::make("M0", 2, 8);
+const auto m0_values_nightly = framework::dataset::make("M0", 1, 8);
/** N0 values to test - Nightly */
const auto n0_values_nightly = framework::dataset::make("N0", { 2, 3, 4, 8 });
@@ -118,10 +118,10 @@
const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, false });
/** Transpose values to test with RHS matrix */
-const auto t_values_rhs = framework::dataset::make("transpose_rhs", { true });
+const auto t_values_rhs = framework::dataset::make("transpose_rhs", { true, false });
/** Configuration test */
-void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int h0_value, bool i_value_rhs, DataType data_type)
+void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int h0_value, bool i_value_rhs, bool t_value_rhs, DataType data_type)
{
const unsigned int M = m_value;
const unsigned int N = n_value;
@@ -136,7 +136,7 @@
rhs_info.k0 = k0_value;
rhs_info.h0 = h0_value;
rhs_info.interleave = i_value_rhs;
- rhs_info.transpose = true;
+ rhs_info.transpose = t_value_rhs;
GEMMReshapeInfo gemm_info(M, N, K);
@@ -168,7 +168,7 @@
TEST_SUITE(GEMMMatrixMultiplyReshapedOnlyRHS)
TEST_SUITE(Float)
TEST_SUITE(FP32)
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(
+DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
@@ -178,9 +178,10 @@
k0_values_precommit),
h0_values_precommit),
i_values_rhs),
-m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, h0_value, i_value_rhs)
+ t_values_rhs),
+m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, h0_value, i_value_rhs, t_value_rhs)
{
- validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, h0_value, i_value_rhs, DataType::F32);
+ validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, h0_value, i_value_rhs, t_value_rhs, DataType::F32);
}
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<float>, framework::DatasetMode::ALL,