COMPMID-2378: Sanitize GEMM configuration for NEON

Change-Id: I7859b82b2059e14685f8792424648ac5eacd67f1
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1418
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/core/Dimensions.h b/arm_compute/core/Dimensions.h
index 0a9264f..9c38c60 100644
--- a/arm_compute/core/Dimensions.h
+++ b/arm_compute/core/Dimensions.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -166,6 +166,29 @@
         collapse(num_dimensions() - start, start);
     }
 
+    /** Remove dimension of a given index
+     *
+     * @note If index is greater than or equal to the number of dimensions no operation is performed
+     * @note Must not be called on an object with zero dimensions (checked with ARM_COMPUTE_ERROR_ON)
+     *
+     * @param[in] idx Dimension index to remove
+     */
+    void remove(size_t idx)
+    {
+        ARM_COMPUTE_ERROR_ON(_num_dimensions < 1);
+        if(idx >= _num_dimensions)
+        {
+            return;
+        }
+
+        // Shift the trailing dimensions down by one, overwriting the removed entry
+        std::copy(_id.begin() + idx + 1, _id.end(), _id.begin() + idx);
+        _num_dimensions--;
+
+        // Make sure all empty dimensions are filled with 0
+        std::fill(_id.begin() + _num_dimensions, _id.end(), 0);
+    }
+
     /** Returns a read/write iterator that points to the first element in the dimension array.
      *
      * @return an iterator.
diff --git a/arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h b/arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h
index 63178a7..352f73d 100644
--- a/arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h
+++ b/arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -45,7 +45,16 @@
         unsigned int multis{ 0 };  /**< Number of "multi" GEMMs (unique A, B and C). */
     };
 
-    static Params extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c);
+    /** Extract the GEMM problem sizes from the given tensors
+     *
+     * @param[in] a         Input tensor (Matrix A)
+     * @param[in] b         Input tensor (Matrix B)
+     * @param[in] c         Output tensor (Matrix C)
+     * @param[in] gemm_info GEMM meta-data
+     *
+     * @return The extracted sizes (M, N, K, multis, ...)
+     */
+    static Params extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c, const GEMMInfo &gemm_info);
 
     /** Constructor */
     INEGEMMWrapperKernel();
@@ -61,13 +70,14 @@
      *
      * @note The input and output tensor must have the same dimensions
      *
-     * @param[in]  a     Input tensor (Matrix A)
-     * @param[in]  b     Input tensor (Matrix B)
-     * @param[out] c     Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
-     * @param[in]  alpha Scalar multiplier to apply to AB matrix product.
-     * @param[in]  beta  Scalar multiplier to apply to input C matrix before adding product.
+     * @param[in]  a         Input tensor (Matrix A)
+     * @param[in]  b         Input tensor (Matrix B)
+     * @param[out] c         Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
+     * @param[in]  alpha     Scalar multiplier to apply to AB matrix product.
+     * @param[in]  beta      Scalar multiplier to apply to input C matrix before adding product.
+     * @param[in]  gemm_info GEMM meta-data
      */
-    void configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta);
+    void configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info);
 
     // Inherited methods overridden:
     void run(const Window &window, const ThreadInfo &info) override;
@@ -95,6 +105,7 @@
     const ITensor *_b;
     ITensor       *_c;
     Params         _params;
+    GEMMInfo       _gemm_info;
 
 private:
     Window      _window3d;
diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h
index e2b849a..40b6f5d 100644
--- a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h
+++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h
@@ -95,31 +95,32 @@
 public:
     /** Configure the matrix multiplication: C = alpha * A * B + beta * C
      *
-     * @param[in]     prepared_a       Already reshaped matrix A.
-     * @param[in]     transformed_b    Already reshaped matrix B.
-     * @param[out]    tmp_c            Temporary buffer to be used to store intermediate results.
-     * @param[in,out] c                Result matrix C.
-     * @param[in]     block_walker     Window containing iteration information for the M and batch dimensions.
-     * @param[in]     block_sizes      Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes).
-     * @param[in]     params           M, N, K sizes.
-     * @param[in]     is_pretransposed Is B also pretransposed ?
-     * @param[in]     alpha            Alpha value
-     * @param[in]     beta             Beta value
-     * @param[in]     max_num_threads  Maximum number of threads that might be used for the calculations.
+     * @param[in]     prepared_a      Already reshaped matrix A.
+     * @param[in]     transformed_b   Already reshaped matrix B.
+     * @param[out]    tmp_c           Temporary buffer to be used to store intermediate results.
+     * @param[in,out] c               Result matrix C.
+     * @param[in]     block_walker    Window containing iteration information for the M and batch dimensions.
+     * @param[in]     block_sizes     Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes).
+     * @param[in]     params          M, N, K sizes.
+     * @param[in]     gemm_info       GEMM meta-data (provides the pretranspose-B flag and the 3D output re-interpretation depth)
+     * @param[in]     alpha           Alpha value
+     * @param[in]     beta            Beta value
+     * @param[in]     max_num_threads Maximum number of threads that might be used for the calculations.
      */
     void configure(const ITensor *prepared_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c, const Window &block_walker, const BlockSizes &block_sizes,
-                   const INEGEMMWrapperKernel::Params &params, bool b_is_pretransposed, float alpha, float beta, unsigned int max_num_threads)
+                   const INEGEMMWrapperKernel::Params &params, const GEMMInfo &gemm_info, float alpha, float beta, unsigned int max_num_threads)
     {
-        _prepared_a         = prepared_a;
-        _transformed_b      = transformed_b;
-        _tmp_c              = tmp_c;
-        _c                  = c;
-        _block_walker       = block_walker;
-        _block_sizes        = block_sizes;
-        _params             = params;
-        _b_is_pretransposed = b_is_pretransposed;
-        _alpha              = alpha;
-        _beta               = beta;
+        _prepared_a          = prepared_a;
+        _transformed_b       = transformed_b;
+        _tmp_c               = tmp_c;
+        _c                   = c;
+        _block_walker        = block_walker;
+        _block_sizes         = block_sizes;
+        _params              = params;
+        _b_is_pretransposed  = gemm_info.pretranpose_B();
+        _reinterpret_c_as_3d = gemm_info.depth_output_gemm3d() != 0;
+        _alpha               = alpha;
+        _beta                = beta;
 
         auto_init_if_empty(*_tmp_c->info(), c->info()->clone()->set_tensor_shape(TensorShape{ _block_sizes.x_block * strategy::out_height(), max_num_threads }));
     }
@@ -133,6 +134,14 @@
         TensorAccessor<typename strategy::result_type>  c(*_c);
         TensorAccessor<typename strategy::result_type>  tmp_c(*_tmp_c);
 
+        // Handle 3d output re-interpretation: drop C's DimZ stride so that index 2 (batch) uses the stride of the original 4th dimension
+        if(_reinterpret_c_as_3d)
+        {
+            Strides c_strides_as_3d = _c->info()->strides_in_bytes();
+            c_strides_as_3d.remove(Window::DimZ);
+            c.set_strides(c_strides_as_3d);
+        }
+
         int                              prev_batch = -1;
         typename strategy::operand_type *a_ptr      = nullptr;
         auto window_iterator                        = arm_compute::create_window_iterator(batch_window, start_offset, end_offset, [&](const Coordinates & id)
@@ -216,9 +225,9 @@
     INEGEMMWrapperKernel::Params   _params{};
     Window                         _block_walker{};
     bool                           _b_is_pretransposed{ false };
+    bool                           _reinterpret_c_as_3d{ false };
     typename strategy::result_type _alpha{};
     typename strategy::result_type _beta{};
 };
-
 } // namespace arm_compute
 #endif /* __ARM_COMPUTE_NEGEMMINTERLEAVEDMATRIXMULTIPLYWRAPPER_H__ */
diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h
index 5d6cd02..b18d327 100644
--- a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h
+++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h
@@ -87,20 +87,22 @@
 public:
     /** Configure the reshape A routine.
      *
-     * @param[in]  a             Input matrix A.
-     * @param[out] transformed_a Reshaped matrix A.
-     * @param[in]  transpose_a   Also transpose A ?
-     * @param[in]  block_walker  Window representing the layout of the matrix's blocks
-     * @param[in]  params        M, N, K sizes.
+     * @param[in]  a                   Input matrix A.
+     * @param[out] transformed_a       Reshaped matrix A.
+     * @param[in]  transpose_a         Also transpose A ?
+     * @param[in]  reinterpret_a_as_3d Re-interpret A as 3D ? If true, A's DimZ stride is dropped when A is read in run()
+     * @param[in]  block_walker        Window representing the layout of the matrix's blocks
+     * @param[in]  params              M, N, K sizes.
      */
-    void configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, const Window &block_walker, const INEGEMMWrapperKernel::Params &params)
+    void configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, bool reinterpret_a_as_3d, const Window &block_walker, const INEGEMMWrapperKernel::Params &params)
     {
-        _a              = a;
-        _transformed_a  = transformed_a;
-        _transpose_a    = transpose_a;
-        _Ksize          = params.K;
-        _Msize          = params.M;
-        _k_multi_window = block_walker.shift_dimensions(1); // block_walker contains (M,K,Multi) --> shift by 1 to get rid of the "M" dimension
+        _a                   = a;
+        _transformed_a       = transformed_a;
+        _transpose_a         = transpose_a;
+        _reinterpret_a_as_3d = reinterpret_a_as_3d;
+        _Ksize               = params.K;
+        _Msize               = params.M;
+        _k_multi_window      = block_walker.shift_dimensions(1); // block_walker contains (M,K,Multi) --> shift by 1 to get rid of the "M" dimension
     }
 
     // Inherited methods overridden:
@@ -110,12 +112,12 @@
         TensorAccessor<typename strategy::operand_type> a(*_a);
         TensorAccessor<typename strategy::operand_type> transformed_a(*_transformed_a);
 
-        if(_a->info()->data_layout() == DataLayout::NHWC)
+        // Handle 3d input re-interpretation: drop A's DimZ stride so that index 2 (batch) uses the stride of the original 4th dimension
+        if(_reinterpret_a_as_3d)
         {
-            // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
-            // the relevant multiple of the row stride.
-            const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _Msize;
-            a.set_stride(2, nhwc_batch_stride);
+            Strides a_strides_as_3d = _a->info()->strides_in_bytes();
+            a_strides_as_3d.remove(Window::DimZ);
+            a.set_strides(a_strides_as_3d);
         }
 
         unsigned int last_m = 0;
@@ -164,8 +166,8 @@
     unsigned int _Msize{ 0 };
     unsigned int _Ksize{ 0 };
     bool         _transpose_a{ false };
+    bool         _reinterpret_a_as_3d{ false };
     Window       _k_multi_window{};
 };
-
 } // namespace arm_compute
 #endif /* __ARM_COMPUTE_NEGEMMINTERLEAVEDTRANSFORMAWRAPPER_H__ */
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index ad679d6..b4d94ec 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -1765,9 +1765,17 @@
 {
 public:
     /** Default constructor */
-    GEMMInfo()
-        : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(true), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false), _retain_internal_weights(false), _gemmlowp_output_stage(),
-          _fp_mixed_precision(false), _broadcast_bias(false)
+    GEMMInfo() noexcept
+        : _is_a_reshaped(false),
+          _is_b_reshaped(false),
+          _reshape_b_only_on_first_run(true),
+          _depth_output_gemm3d(0),
+          _reinterpret_input_as_3d(false),
+          _retain_internal_weights(false),
+          _gemmlowp_output_stage(),
+          _fp_mixed_precision(false),
+          _broadcast_bias(false),
+          _pretranpose_B(true)
     {
     }
     /** Constructor
@@ -1785,10 +1793,17 @@
      * @param[in] broadcast_bias              (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
      */
     GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false,
-             GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool broadcast_bias = false)
-        : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _depth_output_gemm3d(depth_output_gemm3d),
-          _reinterpret_input_as_3d(reinterpret_input_as_3d), _retain_internal_weights(retain_internal_weights), _gemmlowp_output_stage(gemmlowp_output_stage), _fp_mixed_precision(fp_mixed_precision),
-          _broadcast_bias(broadcast_bias)
+             GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool broadcast_bias = false) noexcept
+        : _is_a_reshaped(is_a_reshaped),
+          _is_b_reshaped(is_b_reshaped),
+          _reshape_b_only_on_first_run(reshape_b_only_on_first_run),
+          _depth_output_gemm3d(depth_output_gemm3d),
+          _reinterpret_input_as_3d(reinterpret_input_as_3d),
+          _retain_internal_weights(retain_internal_weights),
+          _gemmlowp_output_stage(gemmlowp_output_stage),
+          _fp_mixed_precision(fp_mixed_precision),
+          _broadcast_bias(broadcast_bias),
+          _pretranpose_B(reshape_b_only_on_first_run)
     {
     }
     /** Flag which specifies if the matrix A has been reshaped
@@ -1865,17 +1880,34 @@
     {
         return _broadcast_bias;
     };
+    /** Flag which specifies whether b should be pre-transposed if supported (initialised to the value of reshape_b_only_on_first_run).
+     *
+     * @return True if b should be pre-transposed else false.
+     */
+    bool pretranpose_B() const
+    {
+        return _pretranpose_B;
+    };
+    /** Set pre-transpose b flag, overriding the value chosen at construction
+     *
+     * @param[in] flag True if b should be pre-transposed (when supported), false otherwise
+     */
+    void set_pretranpose_B(bool flag)
+    {
+        _pretranpose_B = flag;
+    }
 
 private:
-    const bool                    _is_a_reshaped;
-    const bool                    _is_b_reshaped;
-    const bool                    _reshape_b_only_on_first_run;
-    const int                     _depth_output_gemm3d;
-    const bool                    _reinterpret_input_as_3d;
-    const bool                    _retain_internal_weights;
-    const GEMMLowpOutputStageInfo _gemmlowp_output_stage;
-    const bool                    _fp_mixed_precision;
-    const bool                    _broadcast_bias;
+    bool                    _is_a_reshaped;
+    bool                    _is_b_reshaped;
+    bool                    _reshape_b_only_on_first_run;
+    int                     _depth_output_gemm3d;
+    bool                    _reinterpret_input_as_3d;
+    bool                    _retain_internal_weights;
+    GEMMLowpOutputStageInfo _gemmlowp_output_stage;
+    bool                    _fp_mixed_precision;
+    bool                    _broadcast_bias;
+    bool                    _pretranpose_B;
 };
 
 /** Winograd information */
diff --git a/arm_compute/core/WindowIterator.h b/arm_compute/core/WindowIterator.h
index 32d6293..15289b6 100644
--- a/arm_compute/core/WindowIterator.h
+++ b/arm_compute/core/WindowIterator.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -86,6 +86,15 @@
         _strides[dim] = size;
     }
 
+    /** Manually set all the strides at once, replacing the current ones
+     *
+     * @param[in] strides New strides, one entry per dimension
+     */
+    void set_strides(const Strides &strides)
+    {
+        _strides = strides;
+    }
+
     /** Returns a pointer to the element at coordinates (x,y,z,w)
      *
      * @param[in] x X coordinates
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h b/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h
index 2fc2cf4..b5a2978 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -64,17 +64,17 @@
 
     /** If supported create the ACL function corresponding to the GemmMethod provided to process the other passed parameters
      *
-     * @param[in]  method             GemmMethod to use to perform the matrix multiplication.
-     * @param[in]  a                  Input tensor (Matrix A).
-     * @param[in]  b                  Input tensor (Matrix B).
-     * @param[out] d                  Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
-     * @param[in]  alpha              Scalar multiplier to apply to AB matrix product.
-     * @param[in]  beta               Scalar multiplier to apply to input D matrix before adding product.
-     * @param[in]  pretransposed_hint Can the B tensor can be pretransposed (ie shared across invocations)?
+     * @param[in]  method    GemmMethod to use to perform the matrix multiplication.
+     * @param[in]  a         Input tensor (Matrix A).
+     * @param[in]  b         Input tensor (Matrix B).
+     * @param[out] d         Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
+     * @param[in]  alpha     Scalar multiplier to apply to AB matrix product.
+     * @param[in]  beta      Scalar multiplier to apply to input D matrix before adding product.
+     * @param[in]  gemm_info GEMM meta-data (also carries the pretranspose-B hint, see GEMMInfo::pretranpose_B())
      *
      * @return True if the method is supported and the function was successfully created, false otherwise.
      */
-    bool create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint);
+    bool create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info);
 
     /** Interface for the arm_gemm fallback */
     std::unique_ptr<IFallback>      _arm_gemm;
@@ -83,27 +83,27 @@
 public:
     /** If supported create an ACL function else fallback to the arm_gemm function.
      *
-     * @param[in]  a                 Input tensor (Matrix A)
-     * @param[in]  b                 Input tensor (Matrix B)
-     * @param[out] d                 Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
-     * @param[in]  alpha             Scalar multiplier to apply to AB matrix product.
-     * @param[in]  beta              Scalar multiplier to apply to input D matrix before adding product.
-     * @param[in]  pretranspose_hint Can the B tensor can be pretransposed (ie shared across invocations)?
+     * @param[in]  a         Input tensor (Matrix A)
+     * @param[in]  b         Input tensor (Matrix B)
+     * @param[out] d         Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
+     * @param[in]  alpha     Scalar multiplier to apply to AB matrix product.
+     * @param[in]  beta      Scalar multiplier to apply to input D matrix before adding product.
+     * @param[in]  gemm_info GEMM meta-data (also carries the pretranspose-B hint, see GEMMInfo::pretranpose_B())
      */
-    void configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint);
+    void configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info);
 
     /** Indicates whether or not this function can be used to process the given parameters.
      *
-     * @param[in] a                 Input tensor (Matrix A)
-     * @param[in] b                 Input tensor (Matrix B)
-     * @param[in] d                 Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
-     * @param[in] alpha             Scalar multiplier to apply to AB matrix product.
-     * @param[in] beta              Scalar multiplier to apply to input D matrix before adding product.
-     * @param[in] pretranspose_hint Can the B tensor can be pretransposed (ie shared across invocations)?
+     * @param[in] a         Input tensor (Matrix A)
+     * @param[in] b         Input tensor (Matrix B)
+     * @param[in] d         Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
+     * @param[in] alpha     Scalar multiplier to apply to AB matrix product.
+     * @param[in] beta      Scalar multiplier to apply to input D matrix before adding product.
+     * @param[in] gemm_info GEMM meta-data (also carries the pretranspose-B hint, see GEMMInfo::pretranpose_B())
      *
      * @return a status.
      */
-    static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, bool pretranspose_hint);
+    static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info);
     /** Was the function successfully configured ?
      *
      * @return True if the function is configured and ready to run
diff --git a/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h b/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h
index 9495647..ad89e1f 100644
--- a/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h
+++ b/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h
@@ -104,14 +104,14 @@
      *
      * @note The input and output tensor must have the same dimensions
      *
-     * @param[in]  a              Input tensor (Matrix A)
-     * @param[in]  b              Input tensor (Matrix B)
-     * @param[out] c              Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
-     * @param[in]  alpha          Scalar multiplier to apply to AB matrix product.
-     * @param[in]  beta           Scalar multiplier to apply to input C matrix before adding product.
-     * @param[in]  pretranspose_b If true, pretranspose B once during the prepare() stage instead of on the fly every time.
+     * @param[in]  a         Input tensor (Matrix A)
+     * @param[in]  b         Input tensor (Matrix B)
+     * @param[out] c         Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
+     * @param[in]  alpha     Scalar multiplier to apply to AB matrix product.
+     * @param[in]  beta      Scalar multiplier to apply to input C matrix before adding product.
+     * @param[in]  gemm_info GEMM meta-data (also carries the pretranspose-B hint, see GEMMInfo::pretranpose_B())
      */
-    void configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, bool pretranspose_b);
+    void configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info);
 
     // Inherited methods overridden:
     void run() override;