[ONCPUML-968] Fixed format kernel support in additional APIs

Implements required plumbing in order to be able to ask and execute
fixed format kernels from NEFullyConnected, NEGEMM and NEGEMMConv2d.
These APIs are used to accelerate oneDNN primitives (inner product, matrix
multiplication and indirect GEMM respectively) and without changes it
would not be possible to call fixed format kernels from those oneDNN
primitives.

Change-Id: I27534f0491ce28d0ccb98c19f318bd33dcdf2ff5
Signed-off-by: Milos Puzovic <milos.puzovic@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7999
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index c87c97c..66e1c8a 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -774,10 +774,10 @@
 
 private:
     std::pair<unsigned int, unsigned int> _stride;
-    unsigned int _pad_left;
-    unsigned int _pad_top;
-    unsigned int _pad_right;
-    unsigned int _pad_bottom;
+    unsigned int                          _pad_left;
+    unsigned int                          _pad_top;
+    unsigned int                          _pad_right;
+    unsigned int                          _pad_bottom;
 
     DimensionRoundingType _round_type;
 };
@@ -919,14 +919,14 @@
     }
 
 private:
-    std::vector<float> _min_sizes;
-    std::vector<float> _variances;
-    float              _offset;
-    bool               _flip;
-    bool               _clip;
-    std::vector<float> _max_sizes;
-    std::vector<float> _aspect_ratios;
-    Coordinates2D      _img_size;
+    std::vector<float>   _min_sizes;
+    std::vector<float>   _variances;
+    float                _offset;
+    bool                 _flip;
+    bool                 _clip;
+    std::vector<float>   _max_sizes;
+    std::vector<float>   _aspect_ratios;
+    Coordinates2D        _img_size;
     std::array<float, 2> _steps;
 };
 
@@ -1171,15 +1171,15 @@
     }
 
 private:
-    unsigned int _max_detections;
-    unsigned int _max_classes_per_detection;
-    float        _nms_score_threshold;
-    float        _iou_threshold;
-    unsigned int _num_classes;
+    unsigned int         _max_detections;
+    unsigned int         _max_classes_per_detection;
+    float                _nms_score_threshold;
+    float                _iou_threshold;
+    unsigned int         _num_classes;
     std::array<float, 4> _scales_values;
-    bool         _use_regular_nms;
-    unsigned int _detection_per_class;
-    bool         _dequantize_scores;
+    bool                 _use_regular_nms;
+    unsigned int         _detection_per_class;
+    bool                 _dequantize_scores;
 };
 
 /** Pooling Layer Information struct*/
@@ -1612,13 +1612,13 @@
     }
 
 private:
-    float _img_width;
-    float _img_height;
-    float _scale;
-    bool  _apply_scale;
-    bool  _correct_transform_coords;
+    float                _img_width;
+    float                _img_height;
+    float                _scale;
+    bool                 _apply_scale;
+    bool                 _correct_transform_coords;
     std::array<float, 4> _weights;
-    float _bbox_xform_clip;
+    float                _bbox_xform_clip;
 };
 
 /** Activation Layer Information class */
@@ -2053,6 +2053,11 @@
     {
         return _weight_format;
     }
+    void set_weight_format(arm_compute::WeightFormat weight_format)
+    {
+        _weight_format = weight_format;
+    }
+
     unsigned int kernel_width() const
     {
         return _kernel_width;
@@ -2495,11 +2500,29 @@
         return _fixed_format;
     }
 
+    /** Set fixed-format flag
+     *
+     * @param[in] fixed_format sets whether or not to use fixed-format kernels
+     */
+    void set_fixed_format(bool fixed_format)
+    {
+        _fixed_format = fixed_format;
+    }
+
     arm_compute::WeightFormat weight_format() const
     {
         return _weight_format;
     }
 
+    /** Set weight format to be used
+     *
+     * @param[in] weight_format arm_compute::WeightFormat enumeration
+     */
+    void set_weight_format(arm_compute::WeightFormat weight_format)
+    {
+        _weight_format = weight_format;
+    }
+
 private:
     bool                                    _is_a_reshaped;
     bool                                    _is_b_reshaped;
diff --git a/arm_compute/runtime/FunctionDescriptors.h b/arm_compute/runtime/FunctionDescriptors.h
index face8a6..af79820 100644
--- a/arm_compute/runtime/FunctionDescriptors.h
+++ b/arm_compute/runtime/FunctionDescriptors.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -62,8 +62,9 @@
                const ActivationLayerInfo                     &act_info,
                bool                                           enable_fast_math,
                unsigned int                                   num_groups,
-               const experimental::PostOpList<ITensorInfo *> &post_ops = experimental::PostOpList<ITensorInfo *> {})
-        : conv_info(conv_info), dilation(dilation), act_info(act_info), enable_fast_math(enable_fast_math), num_groups(num_groups), post_ops(post_ops)
+               const experimental::PostOpList<ITensorInfo *> &post_ops     = experimental::PostOpList<ITensorInfo *> {},
+               const WeightsInfo                             &weights_info = WeightsInfo())
+        : conv_info(conv_info), dilation(dilation), act_info(act_info), enable_fast_math(enable_fast_math), num_groups(num_groups), post_ops(post_ops), weights_info(weights_info)
     {
     }
 
@@ -73,6 +74,7 @@
     bool                                    enable_fast_math{ false };
     unsigned int                            num_groups{ 1 };
     experimental::PostOpList<ITensorInfo *> post_ops{};
+    WeightsInfo                             weights_info{};
 };
 
 /** Descriptor used by the 3d Convolution function */
diff --git a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
index aa96716..2b4f848 100644
--- a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -112,20 +112,21 @@
      * |QASYMM8        |QASYMM8            |S32    |QASYMM8        |
      * |QASYMM8_SIGNED |QASYMM8_SIGNED     |S32    |QASYMM8_SIGNED |
      *
-     * @param[in]  input   Source tensor. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
-     * @param[in]  weights Weights tensor. The weights must be 2 dimensional.
-     *                     If this function is called after a Convolution Layer, the (transposed) weights will have as many rows as the product of the first 3 input's dimensions.
-     *                     If it is called after another FullyConnected Layer, the (transposed) weights will have as many rows as the input's first dimension.
-     *                     Data type supported: Same as @p input.
-     * @param[in]  biases  Bias tensor. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED.
-     * @param[out] output  Destination tensor. Its shape should be equal to the output of a matrix multiplication between:
-     *                     - The output of im2col on the input and the (transposed) 2D weights, if the function is called after a Convolution Layer
-     *                     - The input tensor and the (transposed) 2D weights, if the function is called after another FullyConnected Layer.
-     *                     Data type supported: Same as @p input.
-     * @param[in]  fc_info (Optional) Fully connected layer additional info
+     * @param[in]  input        Source tensor. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
+     * @param[in]  weights      Weights tensor. The weights must be 2 dimensional.
+     *                          If this function is called after a Convolution Layer, the (transposed) weights will have as many rows as the product of the first 3 input's dimensions.
+     *                          If it is called after another FullyConnected Layer, the (transposed) weights will have as many rows as the input's first dimension.
+     *                          Data type supported: Same as @p input.
+     * @param[in]  biases       Bias tensor. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED.
+     * @param[out] output       Destination tensor. Its shape should be equal to the output of a matrix multiplication between:
+     *                          - The output of im2col on the input and the (transposed) 2D weights, if the function is called after a Convolution Layer
+     *                          - The input tensor and the (transposed) 2D weights, if the function is called after another FullyConnected Layer.
+     *                          Data type supported: Same as @p input.
+     * @param[in]  fc_info      (Optional) Fully connected layer additional info
+     * @param[in]  weights_info (Optional) Stores neccessary compute information when weights are already reshaped
      */
     void configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output,
-                   FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo());
+                   FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo());
     /** Static function to check if given info will lead to a valid configuration of @ref NEFullyConnectedLayer
      *
      * Similar to @ref NEFullyConnectedLayer
@@ -135,6 +136,21 @@
     static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
                            FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo());
 
+    /** Static function that queries whether fixed-format kernel exists for a given problem description
+     *
+     * @param[out] expected_weight_format Format in which weights should be for found fixed format kernel
+     * @param[in]  input                  Source tensor
+     * @param[in]  weights                Weights tensor.
+     * @param[in]  biases                 Bias tensor. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED.
+     * @param[in]  output                 Destination tensor
+     * @param[in]  fc_info                Fully connected layer additional info
+     * @param[in]  weights_info           Describes weights shape
+     *
+     * @return a status
+     */
+    static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *input, const ITensorInfo *weights,
+                               const ITensorInfo *biases, const ITensorInfo *output, const FullyConnectedLayerInfo &fc_info, const WeightsInfo &weights_info);
+
     //Inherited methods override
     void run() override;
     void prepare() override;
diff --git a/arm_compute/runtime/NEON/functions/NEGEMM.h b/arm_compute/runtime/NEON/functions/NEGEMM.h
index ce68a61..7ce2521 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMM.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMM.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -84,6 +84,15 @@
      */
     static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
 
+    /** Static function that queries whether there exists fixed-format kernel and if it exists it will return in the first argument in what format
+     * weights are expected to be reshaped as defined by WeightFormat class. Apart from the first argument the rest of the arguments are the same
+     * as in @ref NEGEMM::validate() except that all arguments are required.
+     *
+     * @return a status
+     */
+    static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output,
+                               float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
+
     // Inherited methods overridden:
     void run() override;
     void prepare() override;
diff --git a/src/cpu/operators/CpuFullyConnected.cpp b/src/cpu/operators/CpuFullyConnected.cpp
index 6d77c61..3172644 100644
--- a/src/cpu/operators/CpuFullyConnected.cpp
+++ b/src/cpu/operators/CpuFullyConnected.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -53,7 +53,7 @@
 {
     PixelValue type_min{};
     PixelValue type_max{};
-    std::tie(type_min, type_max) = get_min_max(data_type);
+    std::tie(type_min, type_max)         = get_min_max(data_type);
     const UniformQuantizationInfo q_unif = q_info.uniform();
 
     if(act_info.enabled())
@@ -162,8 +162,9 @@
       _is_fc_after_conv(false),
       _is_quantized_asymmetric(false),
       _is_prepared(false),
-      _enable_fast_math(false)
-
+      _enable_fast_math(false),
+      _fixed_format(false),
+      _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
 {
 }
 
@@ -199,6 +200,8 @@
         GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */);
         gemm_info.set_activation_info(act);
         gemm_info.set_fast_math(_enable_fast_math);
+        gemm_info.set_fixed_format(_fixed_format);
+        gemm_info.set_weight_format(_weight_format);
         _mm_gemm = std::make_unique<CpuGemm>();
         _mm_gemm->configure(src, weights, biases, dst, 1.f, 1.0f, gemm_info);
     }
@@ -229,7 +232,7 @@
 }
 
 void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst,
-                                  FullyConnectedLayerInfo fc_info)
+                                  FullyConnectedLayerInfo fc_info, const WeightsInfo &weights_info)
 {
     // Perform validate step
     ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst);
@@ -248,6 +251,8 @@
     _is_prepared              = false;
     _trans_weights_idx        = AuxTensorIdx::Count;
     _enable_fast_math         = fc_info.enable_fast_math;
+    _fixed_format             = weights_info.weight_format() != WeightFormat::UNSPECIFIED;
+    _weight_format            = weights_info.weight_format();
 
     // With the Fully Connected layer we can have 4 different cases:
     //  1) Convolution layer -> Fully Connected layer without batches
@@ -261,9 +266,7 @@
     const bool is_batched_fc_layer = dst->dimension(1) > 1;
     if(is_batched_fc_layer)
     {
-        _is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3,
-                                                                                  src->tensor_shape().cend(),
-                                                                                  dst->tensor_shape().cbegin() + 1));
+        _is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3, src->tensor_shape().cend(), dst->tensor_shape().cbegin() + 1));
     }
     else
     {
@@ -323,12 +326,10 @@
     {
         // Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch
         // Do not release them if biases are dynamic and data type is quantized, since the weights tensor will be used for biases offset calculation
-        _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), (_is_quantized_asymmetric
-                                                                                     && biases && !(biases->are_values_constant())) ?
-                                                 MemoryLifetime::Persistent :
-                                                 MemoryLifetime::Prepare,
+        _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), (_is_quantized_asymmetric && biases
+                                                                                     && !(biases->are_values_constant())) ? MemoryLifetime::Persistent : MemoryLifetime::Prepare,
                                                  _reshaped_weights.total_size());
-        _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size());
+        _aux_mem[ConvertedWeights]  = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size());
     }
     else
     {
@@ -338,6 +339,18 @@
     _aux_mem[FlattenedSrc] = MemoryInfo(offset_int_vec(FlattenedSrc), MemoryLifetime::Temporary, _flattened_src.total_size());
 }
 
+Status CpuFullyConnected::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights,
+                                       const ITensorInfo *biases, const ITensorInfo *dst, FullyConnectedLayerInfo fc_info, WeightsInfo weights_info)
+{
+    GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */);
+    gemm_info.set_activation_info(fc_info.activation_info);
+    gemm_info.set_fast_math(fc_info.enable_fast_math);
+    gemm_info.set_fixed_format(weights_info.weight_format() != WeightFormat::UNSPECIFIED);
+    gemm_info.set_weight_format(weights_info.weight_format());
+
+    return CpuGemm::has_opt_impl(expected_weight_format, src, weights, biases, dst, gemm_info);
+}
+
 Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
                                    FullyConnectedLayerInfo fc_info)
 {
@@ -384,9 +397,7 @@
 
     if(is_batched_fc_layer)
     {
-        is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3,
-                                                                                 src->tensor_shape().cend(),
-                                                                                 dst->tensor_shape().cbegin() + 1));
+        is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3, src->tensor_shape().cend(), dst->tensor_shape().cbegin() + 1));
     }
     else
     {
diff --git a/src/cpu/operators/CpuFullyConnected.h b/src/cpu/operators/CpuFullyConnected.h
index 44fa21f..36511e9 100644
--- a/src/cpu/operators/CpuFullyConnected.h
+++ b/src/cpu/operators/CpuFullyConnected.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -72,20 +72,21 @@
      * |QASYMM8        |QASYMM8            |S32    |QASYMM8        |
      * |QASYMM8_SIGNED |QASYMM8_SIGNED     |S32    |QASYMM8_SIGNED |
      *
-     * @param[in]  src     Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
-     * @param[in]  weights Weights tensor info. The weights must be 2 dimensional.
-     *                     If this function is called after a Convolution Layer, the (transposed) weights will have as many rows as the product of the first 3 input's dimensions.
-     *                     If it is called after another FullyConnected Layer, the (transposed) weights will have as many rows as the input's first dimension.
-     *                     Data type supported: Same as @p src.
-     * @param[in]  biases  Bias tensor info. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED.
-     * @param[out] dst     Destination tensor info. Its shape should be equal to the output of a matrix multiplication between:
-     *                     - The output of im2col on the input and the (transposed) 2D weights, if the function is called after a Convolution Layer
-     *                     - The input tensor and the (transposed) 2D weights, if the function is called after another FullyConnected Layer.
-     *                     Data type supported: Same as @p src.
-     * @param[in]  fc_info (Optional) Fully connected layer additional info
+     * @param[in]  src          Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
+     * @param[in]  weights      Weights tensor info. The weights must be 2 dimensional.
+     *                          If this function is called after a Convolution Layer, the (transposed) weights will have as many rows as the product of the first 3 input's dimensions.
+     *                          If it is called after another FullyConnected Layer, the (transposed) weights will have as many rows as the input's first dimension.
+     *                          Data type supported: Same as @p src.
+     * @param[in]  biases       Bias tensor info. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED.
+     * @param[out] dst          Destination tensor info. Its shape should be equal to the output of a matrix multiplication between:
+     *                          - The output of im2col on the input and the (transposed) 2D weights, if the function is called after a Convolution Layer
+     *                          - The input tensor and the (transposed) 2D weights, if the function is called after another FullyConnected Layer.
+     *                          Data type supported: Same as @p src.
+     * @param[in]  fc_info      (Optional) Fully connected layer additional info
+     * @param[in]  weights_info (Optional) Stores neccessary compute information when weights are already reshaped
      */
     void configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst,
-                   FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo());
+                   FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo());
     /** Static function to check if given info will lead to a valid configuration of @ref CpuFullyConnected
      *
      * Similar to @ref CpuFullyConnected
@@ -95,9 +96,19 @@
     static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
                            FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo());
 
+    /** Static function that queries whether there exists fixed-format kernel and if it exists it will return in the first argument in what format
+     * weights are expected to be reshaped as defined by WeightFormat class. Apart from the first argument the rest of the arguments are the same
+     * as in @ref CpuFullyConnectedLayer::validate() except that all arguments are required.
+     *
+     * @return a status
+     */
+    static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights,
+                               const ITensorInfo *biases, const ITensorInfo *dst,
+                               FullyConnectedLayerInfo fc_info, WeightsInfo weights_info);
+
     //Inherited methods override
-    void run(ITensorPack &tensors) override;
-    void prepare(ITensorPack &tensors) override;
+    void                             run(ITensorPack &tensors) override;
+    void                             prepare(ITensorPack &tensors) override;
     experimental::MemoryRequirements workspace() const override;
 
 private:
@@ -136,12 +147,14 @@
 
     experimental::MemoryRequirements _aux_mem;
 
-    bool _needs_weights_conversion;
-    bool _needs_weights_reshape;
-    bool _is_fc_after_conv;
-    bool _is_quantized_asymmetric;
-    bool _is_prepared;
-    bool _enable_fast_math;
+    bool                      _needs_weights_conversion;
+    bool                      _needs_weights_reshape;
+    bool                      _is_fc_after_conv;
+    bool                      _is_quantized_asymmetric;
+    bool                      _is_prepared;
+    bool                      _enable_fast_math;
+    bool                      _fixed_format;
+    arm_compute::WeightFormat _weight_format;
 };
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/operators/CpuGemmDirectConv2d.cpp b/src/cpu/operators/CpuGemmDirectConv2d.cpp
index fd1a042..ee47a17 100644
--- a/src/cpu/operators/CpuGemmDirectConv2d.cpp
+++ b/src/cpu/operators/CpuGemmDirectConv2d.cpp
@@ -57,11 +57,11 @@
                                                                                ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
                                                                                ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU
                                                                              };
-    PixelValue type_min{};
-    PixelValue type_max{};
+    PixelValue                                              type_min{};
+    PixelValue                                              type_max{};
     std::tie(type_min, type_max) = get_min_max(data_type);
-    int32_t min_activation = type_min.get<int32_t>();
-    int32_t max_activation = type_max.get<int32_t>();
+    int32_t min_activation       = type_min.get<int32_t>();
+    int32_t max_activation       = type_max.get<int32_t>();
     if(supported_acts.count(act.activation()) != 0)
     {
         std::tie(min_activation, max_activation) = get_quantized_activation_min_max(act, data_type, uoqinfo);
@@ -88,6 +88,8 @@
     asm_info.padding_value           = 0.f;
     asm_info.negated_offsets         = false;
     asm_info.fast_mode               = info.enable_fast_math;
+    asm_info.fixed_format            = info.weights_info.weight_format() != WeightFormat::UNSPECIFIED;
+    asm_info.weight_format           = info.weights_info.weight_format();
     return asm_info;
 }
 } // namespace
@@ -146,7 +148,9 @@
     }
     else
     {
-        _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size());
+        // We must permute weights if they are WeightFormat::UNSPECIFIED
+        if(info.weights_info.weight_format() == WeightFormat::UNSPECIFIED)
+            _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size());
     }
 }
 Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &info)
@@ -203,6 +207,13 @@
 {
     if(!_is_prepared)
     {
+        // If we are using fixed-format kernel the weights are already reshaped
+        if(_gemm_asm_func && _gemm_asm_func->isVarWeightsKernel())
+        {
+            _gemm_asm_func->prepare(tensors);
+            _is_prepared = true;
+            return;
+        }
         const ITensor *weights     = tensors.get_const_tensor(ACL_SRC_1);
         ITensor       *weights_aux = utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(PermutedWeights)));
         ARM_COMPUTE_ERROR_ON_NULLPTR(weights, weights_aux);
@@ -224,4 +235,4 @@
     return _aux_mem;
 }
 } // namespace cpu
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 45b3232..df02d64 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -156,8 +156,8 @@
                                                                                             const std::vector<int32_t> &multipliers);
 
     // Inherited methods overridden:
-    void run(ITensorPack &tensors) override;
-    void prepare(ITensorPack &tensors) override;
+    void                             run(ITensorPack &tensors) override;
+    void                             prepare(ITensorPack &tensors) override;
     bool                             is_configured() const override;
     experimental::MemoryRequirements workspace() const override;
     bool                             isVarWeightsKernel() const override
@@ -210,12 +210,12 @@
     /** Indirect buffer */
     std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
     std::unique_ptr<const TypeInput *, free_delete>        _indirect_buf{};
-    std::vector<TypeInput>           _indirect_pad{};
-    arm_gemm::ConvolutionParameters  _cp{};
-    experimental::MemoryRequirements _aux_mem{ Count };
-    bool                             _B_pretranspose_required{ false };
-    bool                             _is_b_constant{ true };
-    bool                             _is_c_constant{ true };
+    std::vector<TypeInput>                                 _indirect_pad{};
+    arm_gemm::ConvolutionParameters                        _cp{};
+    experimental::MemoryRequirements                       _aux_mem{ Count };
+    bool                                                   _B_pretranspose_required{ false };
+    bool                                                   _is_b_constant{ true };
+    bool                                                   _is_c_constant{ true };
 };
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -493,6 +493,7 @@
     if(!_gemm_kernel_asm->B_is_pretransposed())
     {
         ldb                                = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+        multi_stride_b                     = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
         const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
         if(is_fixed_format(wf))
         {
@@ -501,17 +502,35 @@
             // as a 2D tensor at arm_gemm level, where the rows are
             // O'/<interleave_by> and the columns are <interleave_by> *
             // H * W * I'.
-            ITensorInfo      *tensor_info   = b->info();
-            const DataLayout  data_layout   = tensor_info->data_layout();
-            const TensorShape tensor_shape  = tensor_info->tensor_shape();
-            const int         H             = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
-            const int         W             = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
-            const int         Ip            = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)];
-            const int         interleave_by = arm_compute::interleave_by(wf);
-            ldb                             = (interleave_by * H * W * Ip);
+            ITensorInfo      *tensor_info     = b->info();
+            const DataLayout  data_layout     = tensor_info->data_layout();
+            const TensorShape tensor_shape    = tensor_info->tensor_shape();
+            const int         tensor_height   = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
+            const int         tensor_width    = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
+            const int         tensor_channels = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)];
+            const int         interleave_by   = arm_compute::interleave_by(wf);
+            // We need to find a new stride that is distance from the data for one
+            // set of output channels to the next
+            if(ldb == tensor_channels && multi_stride_b == tensor_channels * tensor_width)
+            {
+                // In this case dimensions that are packed are height, width and channel
+                // so we need to stride it by interleave_by
+                ldb = interleave_by * tensor_height * tensor_width * tensor_channels;
+            }
+            else if(multi_stride_b == 0 || (ldb == tensor_width && multi_stride_b == tensor_height * tensor_width))
+            {
+                // In this case dimension that is packed is only height
+                // so we need to stride only height by interleave_by
+                ldb = interleave_by * tensor_height;
+            }
+            else
+            {
+                // If dimensions are not packed as above error is thrown
+                // as at the moment other forms of packing are not supported
+                ARM_COMPUTE_ERROR("Unsupported packing for fixed format kernel");
+            }
         }
-        multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
-        in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
+        in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
     }
 
     // If necessary, run pretranspose every time if either weights or biases are non-constant
diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
index 77028d9..4f858fb 100644
--- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
+++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -61,7 +61,7 @@
 }
 
 void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output,
-                                      FullyConnectedLayerInfo fc_info)
+                                      FullyConnectedLayerInfo fc_info, const WeightsInfo &weights_info)
 {
     // Perform validate step
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
@@ -76,7 +76,7 @@
     _impl->original_weights = weights;
     _impl->is_prepared      = false;
 
-    _impl->op->configure(input->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, output->info(), fc_info);
+    _impl->op->configure(input->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, output->info(), fc_info, weights_info);
 
     if(_impl->weights_manager != nullptr)
     {
@@ -88,6 +88,13 @@
     _impl->workspace   = manage_workspace<Tensor>(_impl->aux_mem_req, _impl->memory_group, _impl->run_pack, _impl->run_pack);
 }
 
+Status NEFullyConnectedLayer::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *input, const ITensorInfo *weights,
+                                           const ITensorInfo *biases, const ITensorInfo *output, const FullyConnectedLayerInfo &fc_info,
+                                           const WeightsInfo &weights_info)
+{
+    return cpu::CpuFullyConnected::has_opt_impl(expected_weight_format, input, weights, biases, output, fc_info, weights_info);
+}
+
 Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
                                        FullyConnectedLayerInfo fc_info)
 {
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index 58ade9f..0266c48 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -84,6 +84,13 @@
     return cpu::CpuGemm::validate(a, b, c, output, alpha, beta, gemm_info);
 }
 
+Status NEGEMM::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output,
+                            float alpha, float beta, const GEMMInfo &gemm_info)
+{
+    ARM_COMPUTE_UNUSED(alpha, beta);
+    return cpu::CpuGemm::has_opt_impl(expected_weight_format, a, b, c, output, gemm_info);
+}
+
 void NEGEMM::run()
 {
     prepare();