COMPMID-2002: Implement CLGEMMLowpMatrixMultiplyReshapedOnlyRHS - Transposed

Change-Id: I3907d151107766dc34749fe5710d7219e810b39f
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/875
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index 90a4b5c..5793ebd 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -611,6 +611,214 @@
     TensorType            _target{};
     SimpleTensor<int32_t> _reference{};
 };
+
+template <typename TensorType, typename AccessorType, typename ReshapeRHSFunctionType, typename GEMMFunctionType>
+class GEMMLowpMatrixMultiplyReshapedOnlyRHSValidationFixture : public framework::Fixture // Produces a target result (RHS reshape function + GEMM function) and a reference result for the same inputs
+{
+public:
+    template <typename...>
+    void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int h0, bool interleave_rhs, bool transpose_rhs) // Populates _target and _reference from the GEMM dimensions and block parameters
+    {
+        GEMMLHSMatrixInfo lhs_info; // Only m0 and k0 are assigned on the LHS info
+        lhs_info.m0 = m0;
+        lhs_info.k0 = k0;
+
+        GEMMRHSMatrixInfo rhs_info; // Passed to both the RHS reshape function and the GEMM function in compute_target()
+        rhs_info.n0         = n0;
+        rhs_info.k0         = k0;
+        rhs_info.h0         = h0;
+        rhs_info.interleave = interleave_rhs;
+        rhs_info.transpose  = transpose_rhs;
+
+        // Set the tensor shapes for LHS and RHS matrices: LHS is (k, m, batch_size), RHS is (n, k, batch_size)
+        const TensorShape lhs_shape(k, m, batch_size);
+        const TensorShape rhs_shape(n, k, batch_size);
+
+        _target    = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info);
+        _reference = compute_reference(lhs_shape, rhs_shape); // The reference path only needs the shapes, not the block parameters
+    }
+
+protected:
+    template <typename U>
+    void fill(U &&tensor, int i) // Fills tensor via library->fill; i is forwarded unchanged (presumably a seed offset - TODO confirm)
+    {
+        // Values are drawn uniformly between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
+        std::uniform_int_distribution<> distribution(1, 254);
+        library->fill(tensor, distribution, i);
+    }
+
+    TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info) // Runs the RHS reshape followed by the GEMM and returns the destination tensor
+    {
+        // Create tensors: lhs and rhs are QASYMM8; rhs_reshaped and dst are default-constructed here
+        TensorType lhs = create_tensor<TensorType>(lhs_shape, DataType::QASYMM8, 1);
+        TensorType rhs = create_tensor<TensorType>(rhs_shape, DataType::QASYMM8, 1);
+        TensorType rhs_reshaped;
+        TensorType dst;
+
+        const unsigned int M = lhs_shape[1]; // setup() builds lhs_shape as (k, m, batch_size)
+        const unsigned int N = rhs_shape[0]; // setup() builds rhs_shape as (n, k, batch_size)
+        const unsigned int K = lhs_shape[0];
+
+        // The output tensor (dst) will be auto-initialized within the function
+
+        // Create and configure function: rhs is reshaped into rhs_reshaped, which the GEMM takes as its second input
+        ReshapeRHSFunctionType reshape_rhs;
+        GEMMFunctionType       gemm;
+        reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info);
+        gemm.configure(&lhs, &rhs_reshaped, &dst, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K));
+
+        ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS); // Inputs are expected to still be resizable before allocation
+        ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+        // Allocate all four tensors
+        lhs.allocator()->allocate();
+        rhs.allocator()->allocate();
+        rhs_reshaped.allocator()->allocate();
+        dst.allocator()->allocate();
+
+        ARM_COMPUTE_EXPECT(!lhs.info()->is_resizable(), framework::LogLevel::ERRORS); // After allocation no tensor is expected to be resizable
+        ARM_COMPUTE_EXPECT(!rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(!rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+        // Fill tensors (indices 0 and 1 are the same ones compute_reference() uses for lhs and rhs)
+        fill(AccessorType(lhs), 0);
+        fill(AccessorType(rhs), 1);
+
+        // Compute GEMM: the reshape runs first because the GEMM was configured to read rhs_reshaped
+        reshape_rhs.run();
+        gemm.run();
+
+        return dst;
+    }
+
+    SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape) // Computes the expected result from the un-reshaped LHS and RHS
+    {
+        TensorShape dst_shape = lhs_shape; // Start from the LHS shape so the higher dimensions carry over
+        dst_shape[0]          = rhs_shape[0]; // Dimension 0 becomes n (see the shapes built in setup())
+        dst_shape[1]          = lhs_shape[1]; // Dimension 1 stays m; this equals the value already copied above
+
+        // Create reference inputs
+        SimpleTensor<uint8_t> lhs{ lhs_shape, DataType::QASYMM8, 1 };
+        SimpleTensor<uint8_t> rhs{ rhs_shape, DataType::QASYMM8, 1 };
+
+        // Fill reference (same fill indices as compute_target(): 0 for lhs, 1 for rhs)
+        fill(lhs, 0);
+        fill(rhs, 1);
+
+        return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0); // NOTE(review): trailing 0, 0 presumably are the LHS/RHS offsets - TODO confirm
+    }
+
+    TensorType            _target{}; // Result of compute_target()
+    SimpleTensor<int32_t> _reference{}; // Result of compute_reference()
+};
+
+template <typename TensorType, typename AccessorType, typename ReshapeRHSFunctionType, typename GEMMFunctionType>
+class GEMMLowpMatrixMultiplyReshapedOnlyRHS3DValidationFixture : public framework::Fixture // Variant whose m dimension is given as m_w x m_h and whose reference output shape is 4D
+{
+public:
+    template <typename...>
+    void setup(unsigned int m_w, unsigned int m_h, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int h0,
+               bool interleave_rhs, bool transpose_rhs) // Populates _target and _reference from the GEMM dimensions and block parameters
+    {
+        GEMMLHSMatrixInfo lhs_info; // Only m0 and k0 are assigned on the LHS info
+        lhs_info.m0 = m0;
+        lhs_info.k0 = k0;
+
+        GEMMRHSMatrixInfo rhs_info; // Passed to both the RHS reshape function and the GEMM function in compute_target()
+        rhs_info.n0         = n0;
+        rhs_info.k0         = k0;
+        rhs_info.h0         = h0;
+        rhs_info.interleave = interleave_rhs;
+        rhs_info.transpose  = transpose_rhs;
+
+        // In case of GEMM3D, m is the product between m_w and m_h
+        const unsigned int m = m_w * m_h;
+
+        // Set the tensor shapes for LHS and RHS matrices: LHS is (k, m, batch_size), RHS is (n, k, batch_size)
+        const TensorShape lhs_shape(k, m, batch_size);
+        const TensorShape rhs_shape(n, k, batch_size);
+
+        _target    = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, m_h);
+        _reference = compute_reference(lhs_shape, rhs_shape, m_h); // m_h lets the reference split m back into (m_w, m_h)
+    }
+
+protected:
+    template <typename U>
+    void fill(U &&tensor, int i) // Fills tensor via library->fill; i is forwarded unchanged (presumably a seed offset - TODO confirm)
+    {
+        // Values are drawn uniformly between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
+        std::uniform_int_distribution<> distribution(1, 254);
+        library->fill(tensor, distribution, i);
+    }
+
+    TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, unsigned int m_h) // Runs the RHS reshape followed by the GEMM and returns the destination tensor
+    {
+        // Create tensors: lhs and rhs are QASYMM8; rhs_reshaped and dst are default-constructed here
+        TensorType lhs = create_tensor<TensorType>(lhs_shape, DataType::QASYMM8, 1);
+        TensorType rhs = create_tensor<TensorType>(rhs_shape, DataType::QASYMM8, 1);
+        TensorType rhs_reshaped;
+        TensorType dst;
+
+        const unsigned int M = lhs_shape[1]; // setup() builds lhs_shape as (k, m_w * m_h, batch_size)
+        const unsigned int N = rhs_shape[0]; // setup() builds rhs_shape as (n, k, batch_size)
+        const unsigned int K = lhs_shape[0];
+
+        // The output tensor (dst) will be auto-initialized within the function
+
+        // Create and configure function; m_h is the sixth GEMMReshapeInfo argument (presumably the 3D output depth - TODO confirm)
+        ReshapeRHSFunctionType reshape_rhs;
+        GEMMFunctionType       gemm;
+        reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info);
+        gemm.configure(&lhs, &rhs_reshaped, &dst, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h));
+
+        ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS); // Inputs are expected to still be resizable before allocation
+        ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+        // Allocate all four tensors
+        lhs.allocator()->allocate();
+        rhs.allocator()->allocate();
+        rhs_reshaped.allocator()->allocate();
+        dst.allocator()->allocate();
+
+        ARM_COMPUTE_EXPECT(!lhs.info()->is_resizable(), framework::LogLevel::ERRORS); // After allocation no tensor is expected to be resizable
+        ARM_COMPUTE_EXPECT(!rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(!rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+        // Fill tensors (indices 0 and 1 are the same ones compute_reference() uses for lhs and rhs)
+        fill(AccessorType(lhs), 0);
+        fill(AccessorType(rhs), 1);
+
+        // Compute GEMM: the reshape runs first because the GEMM was configured to read rhs_reshaped
+        reshape_rhs.run();
+        gemm.run();
+
+        return dst;
+    }
+
+    SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, unsigned int m_h) // Computes the expected result from the un-reshaped LHS and RHS with a 4D output shape
+    {
+        TensorShape dst_shape = lhs_shape; // Start from the LHS shape; the first four dimensions are overwritten below
+        dst_shape.set(0, rhs_shape[0]); // n
+        dst_shape.set(1, lhs_shape[1] / m_h); // m / m_h, i.e. m_w given how setup() builds lhs_shape
+        dst_shape.set(2, m_h);
+        dst_shape.set(3, lhs_shape[2]); // batch_size moves from dimension 2 to dimension 3
+
+        // Create reference inputs
+        SimpleTensor<uint8_t> lhs{ lhs_shape, DataType::QASYMM8, 1 };
+        SimpleTensor<uint8_t> rhs{ rhs_shape, DataType::QASYMM8, 1 };
+
+        // Fill reference (same fill indices as compute_target(): 0 for lhs, 1 for rhs)
+        fill(lhs, 0);
+        fill(rhs, 1);
+
+        return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0); // NOTE(review): trailing 0, 0 presumably are the LHS/RHS offsets - TODO confirm
+    }
+
+    TensorType            _target{}; // Result of compute_target()
+    SimpleTensor<int32_t> _reference{}; // Result of compute_reference()
+};
 } // namespace validation
 } // namespace test
 } // namespace arm_compute