Remove fixed format strides hack

- Remove the hack in CpuGemmAssemblyDispatch.cpp that tried to guess
  strides for fixed format kernels. Instead, expect that the strides
  have been set correctly on the weights externally (see the sketch
  below)
- Update fixed format test fixtures to set the strides
- If the fixed format kernel uses fast math mode, the weights should be
  of type BFLOAT16. Change the validation logic to accept this.
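
As a rough illustration of the stride setup now expected from callers,
the sketch below prepares the TensorInfo for an
OHWIo<interleave_by>i<block_by> weights tensor. It is not code added by
this patch: prepare_fixed_format_weights is an illustrative helper name,
an NHWC-layout OHWI weights tensor is assumed, and the y-stride follows
the 2D-view arithmetic of the removed hack (rows of O'/<interleave_by>,
columns of <interleave_by> * H * W * I'):

    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/Types.h"
    #include "arm_compute/core/Utils.h"

    arm_compute::TensorInfo prepare_fixed_format_weights(const arm_compute::TensorInfo  &info,
                                                         const arm_compute::WeightFormat wf)
    {
        using namespace arm_compute;

        // For NHWC-layout OHWI weights the shape is (I, W, H, O), innermost first.
        const TensorShape shape      = info.tensor_shape();
        const int         I          = shape[0];
        const int         W          = shape[1];
        const int         H          = shape[2];
        const int         O          = shape[3];
        const int         interleave = interleave_by(wf); // output channels packed in groups of this size
        const int         block      = block_by(wf);      // input channels padded to a multiple of this
        const int         Ip         = (I + block - 1) / block * block;           // padded I'
        const int         Op         = (O + interleave - 1) / interleave * interleave; // padded O'

        // Fast math fixed format kernels consume BFLOAT16 weights (see the
        // validate() changes in this patch); otherwise keep the original type.
        const DataType dt = is_fixed_format_fast_math(wf) ? DataType::BFLOAT16 : info.data_type();
        const size_t   es = data_size_from_type(dt);

        // arm_gemm sees the 4D O'HWI' tensor as a 2D matrix with O'/<interleave_by>
        // rows of <interleave_by> * H * W * I' columns, so the y-stride must be
        // the distance from one group of output channels to the next.
        Strides strides = info.strides_in_bytes();
        strides.set(1, interleave * H * W * Ip * es);

        TensorInfo weights_info(info);
        weights_info.init(shape, 1 /* num_channels */, dt, strides,
                          info.offset_first_element_in_bytes(),
                          static_cast<size_t>(Op) * H * W * Ip * es);
        return weights_info;
    }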

Resolves: [ONCPUML-1131]

Co-authored-by: Milos Puzovic <Milos.Puzovic@arm.com>
Change-Id: I0f18d8b86b0f639be25fd122fa06a591e90645f2
Signed-off-by: Jonathan Deakin <jonathan.deakin@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8985
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/operators/CpuFullyConnected.cpp b/src/cpu/operators/CpuFullyConnected.cpp
index 3172644..1e1598a 100644
--- a/src/cpu/operators/CpuFullyConnected.cpp
+++ b/src/cpu/operators/CpuFullyConnected.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -109,7 +109,7 @@
     return Status{};
 }
 
-Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const ActivationLayerInfo &act, bool enable_fast_math)
+Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const ActivationLayerInfo &act, bool enable_fast_math, WeightFormat weight_format)
 {
     if(is_data_type_quantized_asymmetric(src->data_type()))
     {
@@ -137,6 +137,8 @@
     else
     {
         GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */);
+        gemm_info.set_weight_format(weight_format);
+        gemm_info.set_fixed_format(weight_format != WeightFormat::UNSPECIFIED);
         gemm_info.set_fast_math(enable_fast_math);
         ARM_COMPUTE_RETURN_ON_ERROR(CpuGemm::validate(src, weights, biases, dst, 1.f, 1.0f, gemm_info));
     }
@@ -240,7 +242,8 @@
                                                            weights,
                                                            biases != nullptr ? biases : nullptr,
                                                            dst,
-                                                           fc_info));
+                                                           fc_info,
+                                                           weights_info));
     ARM_COMPUTE_LOG_PARAMS(src, weights, biases, dst, fc_info);
 
     _needs_weights_conversion = false;
@@ -352,12 +355,23 @@
 }
 
 Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
-                                   FullyConnectedLayerInfo fc_info)
+                                   FullyConnectedLayerInfo fc_info, const WeightsInfo &weights_info)
 {
     ARM_COMPUTE_UNUSED(fc_info.retain_internal_weights);
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights, dst);
+
+    if (is_fixed_format_fast_math(weights_info.weight_format()))
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(src, DataType::F32);
+        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(weights, DataType::BFLOAT16);
+        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32);
+    }
+    else
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights, dst);
+    }
+
     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
     ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(src->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU
                                 && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
@@ -436,7 +450,7 @@
         ARM_COMPUTE_RETURN_ERROR_ON(src->dimension(0) != weights_to_use->dimension(1));
     }
     // Validate matrix multiply kernel
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(src_to_use, weights_to_use, biases, dst, fc_info.activation_info, fc_info.enable_fast_math));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(src_to_use, weights_to_use, biases, dst, fc_info.activation_info, fc_info.enable_fast_math, weights_info.weight_format()));
 
     return Status{};
 }
diff --git a/src/cpu/operators/CpuFullyConnected.h b/src/cpu/operators/CpuFullyConnected.h
index 36511e9..9cd67f2 100644
--- a/src/cpu/operators/CpuFullyConnected.h
+++ b/src/cpu/operators/CpuFullyConnected.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -89,12 +89,12 @@
                    FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo());
     /** Static function to check if given info will lead to a valid configuration of @ref CpuFullyConnected
      *
-     * Similar to @ref CpuFullyConnected
+     * Similar to @ref CpuFullyConnected::configure()
      *
      * @return a status
      */
     static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
-                           FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo());
+                           FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo());
 
     /** Static function that queries whether there exists fixed-format kernel and if it exists it will return in the first argument in what format
      * weights are expected to be reshaped as defined by WeightFormat class. Apart from the first argument the rest of the arguments are the same
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index a17e4f3..545d59f 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -158,7 +158,17 @@
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::BFLOAT16, DataType::F16, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
+
+    if (is_fixed_format_fast_math(gemm_info.weight_format()))
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
+        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
+    }
+    else
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
+    }
+
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index f3a16f1..9bf6ed1 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -427,7 +427,12 @@
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights_info.are_reshaped(), "Weights already reshaped are not supported!");
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::BFLOAT16, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::BFLOAT16, DataType::F16, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, weights);
+
+    if (!is_fixed_format(weights_info.weight_format()))
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, weights);
+    }
+
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups > 1, "Grouping (num_groups != 1) is not supported");
 
     const DataLayout data_layout = src->data_layout();
@@ -493,7 +498,7 @@
     unsigned int mat_weights_cols = weights->dimension(idx_kernels);
     unsigned int mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel);
 
-    weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, append_bias), 1, data_type);
+    weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, append_bias), 1, weights->data_type());
     weights_reshaped_info.set_quantization_info(weights->quantization_info());
     weights_to_use = &weights_reshaped_info;
 
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 8ff81af..bf3ec5a 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2022 Arm Limited.
+ * Copyright (c) 2018-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -430,9 +430,9 @@
         {
             // Fixed format kernels need no pretranspose.
             ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
-            const int  ldb            = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+            const int  ldb            = b->info()->strides_in_bytes().y() / b->info()->element_size();
             const auto in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
-            const int  multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+            const int  multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
 
             CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
             ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
@@ -470,21 +470,21 @@
     auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
     auto d = tensors.get_tensor(TensorType::ACL_DST);
 
-    int       lda = a->info()->strides_in_bytes().y() / sizeof(TypeInput);
+    int       lda = a->info()->strides_in_bytes().y() / a->info()->element_size();
     int       ldb = 0;
-    const int ldd = d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
+    const int ldd = d->info()->strides_in_bytes().y() / d->info()->element_size();
 
     const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d != 0 ? 3 : 2;
     const size_t a_multi_idx = a_batch_idx + 1;
     const size_t d_batch_idx = _gemm_info.depth_output_gemm3d != 0 ? 3 : 2;
     const size_t d_multi_idx = d_batch_idx + 1;
 
-    int       batch_stride_a = a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput);
-    const int batch_stride_d = d->info()->strides_in_bytes()[d_batch_idx] / sizeof(TypeOutput);
+    int       batch_stride_a = a->info()->strides_in_bytes()[a_batch_idx] / a->info()->element_size();
+    const int batch_stride_d = d->info()->strides_in_bytes()[d_batch_idx] / d->info()->element_size();
 
-    int       multi_stride_a = a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput);
+    int       multi_stride_a = a->info()->strides_in_bytes()[a_multi_idx] / a->info()->element_size();
     int       multi_stride_b = 0;
-    const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / sizeof(TypeOutput);
+    const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / d->info()->element_size();
 
     auto             in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes());
     const TypeInput *in1_ptr = nullptr;
@@ -493,50 +493,8 @@
     // Check if B is pre-tranposed and de-reference if not
     if(!_gemm_kernel_asm->B_is_pretransposed())
     {
-        ldb                                = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
-        multi_stride_b                     = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
-        const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
-        if(is_fixed_format(wf))
-        {
-            // The 4D tensor of dimension O'HWI' created for the
-            // OHWIo<interleave_by>i<block_by> format is in reality seen
-            // as a 2D tensor at arm_gemm level, where the rows are
-            // O'/<interleave_by> and the columns are <interleave_by> *
-            // H * W * I'.
-            ITensorInfo      *tensor_info     = b->info();
-            const DataLayout  data_layout     = tensor_info->data_layout();
-            const TensorShape tensor_shape    = tensor_info->tensor_shape();
-            const int         tensor_height   = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
-            const int         tensor_width    = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
-            int               tensor_channels = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)];
-            const int         interleave_by   = arm_compute::interleave_by(wf);
-            const int         blocked_by      = arm_compute::block_by(wf);
-            // We need to find a new stride that is distance from the data for one
-            // set of output channels to the next
-            if(ldb == tensor_channels && multi_stride_b == tensor_channels * tensor_width)
-            {
-                // In this case dimensions that are packed are height, width and channel
-                // so we need to stride it by interleave_by
-                if(tensor_channels % blocked_by != 0)
-                {
-                    // We need to pad
-                    tensor_channels = arm_gemm::iceildiv(tensor_channels, blocked_by) * blocked_by;
-                }
-                ldb = interleave_by * tensor_height * tensor_width * tensor_channels;
-            }
-            else if(multi_stride_b == 0 || (ldb == tensor_width && multi_stride_b == tensor_height * tensor_width))
-            {
-                // In this case dimension that is packed is only height
-                // so we need to stride only height by interleave_by
-                ldb = interleave_by * tensor_height;
-            }
-            else
-            {
-                // If dimensions are not packed as above error is thrown
-                // as at the moment other forms of packing are not supported
-                ARM_COMPUTE_ERROR("Unsupported packing for fixed format kernel");
-            }
-        }
+        ldb                                = b->info()->strides_in_bytes().y() / b->info()->element_size();
+        multi_stride_b                     = b->info()->strides_in_bytes().z() / b->info()->element_size();
         in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
     }
 
@@ -551,9 +509,9 @@
         // Pretranspose B if required
         if(_B_pretranspose_required)
         {
-            const int  ldb            = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+            const int  ldb            = b->info()->strides_in_bytes().y() / b->info()->element_size();
             const auto b_ptr          = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
-            const int  multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+            const int  multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
 
             CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
             ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
@@ -780,6 +738,11 @@
     {
         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::QASYMM8_SIGNED, DataType::S8);
     }
+    else if(is_fixed_format_fast_math(info.weight_format))
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
+        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
+    }
     else
     {
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);