Add validate tests for CLConvolutionLayer and CLGEMMConvolutionLayer with post ops

* Add validate tests
* Restrict post ops support in ClGemmConv2d to only those that do not
  need im2col or col2im. In practice this means we only support post ops
  in conv1x1 with stride = 1, dilation = 1 and data layout = NHWC

Resolves COMPMID-4435

Change-Id: I1fdf0c5d565a4624857250075ac76db35c2f383b
Signed-off-by: SiCongLi <sicong.li@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6573
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/core/experimental/IPostOp.h b/arm_compute/core/experimental/IPostOp.h
index 178c83a..567a402 100644
--- a/arm_compute/core/experimental/IPostOp.h
+++ b/arm_compute/core/experimental/IPostOp.h
@@ -71,6 +71,15 @@
  *          * post_op_arg1 = [1, 1, 34] is allowed: broadcast in dims 0 and 1
  *          * post_op_arg1 = [14, 15, 34] is NOT allowed: broadcast widens the dst tensor
  *
+ * @note: On Data layout
+ *      All post ops are data layout agnostic. This means post ops do not have an inherent idea of "width", "height" and so on.
+ *      Should we want to perform a post op with 2 tensors of different data layouts (where data layouts are significant to both),
+ *      then we need to perform necessary permutation op beforehand to unify their data layout before they can be fused with a post op
+ *
+ *      Note although post ops themselves should be able to support any data layout, the main op they fuse to may impose
+ *      additional restrictions in the presence of post ops. For example, the implementation of a gemm op may only allow
+ *      NHWC data layout if post ops are provided. Such restrictions are main op implementation specific.
+ *
  *  @note: PostOps do not own any resources pointed to by TensorRelatedT if it's a pointer type
  *  @note: If TensorRelatedT points to a resource, IPostOp assumes that resource is valid throughout its lifetime
  *        and the lifetime of its copies. This is almost guaranteed as IPostOp is only meant to be used at configure time
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index f18f5b7..3e8b024 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -703,20 +703,18 @@
 
 /** Calculate the deep convolution shape output shape of a tensor
  *
- * @param[in] input     Input tensor info
- * @param[in] weights   Weights tensor info
- * @param[in] conv_info Contains padding and stride information
+ * @param[in] input_shape       Input tensor shape
+ * @param[in] input_data_layout Input data layout
+ * @param[in] weights_shape     Weights tensor shape
+ * @param[in] conv_info         Contains padding and stride information
  *
  * @return the calculated shape
  */
-inline TensorShape compute_deep_convolution_shape(const ITensorInfo &input, const ITensorInfo &weights, PadStrideInfo conv_info)
+inline TensorShape compute_deep_convolution_shape(const TensorShape &input_shape, DataLayout input_data_layout, const TensorShape &weights_shape, const PadStrideInfo &conv_info)
 {
-    const TensorShape input_shape{ input.tensor_shape() };
-    const TensorShape weights_shape{ weights.tensor_shape() };
-
-    const size_t idx_width   = get_data_layout_dimension_index(input.data_layout(), DataLayoutDimension::WIDTH);
-    const size_t idx_height  = get_data_layout_dimension_index(input.data_layout(), DataLayoutDimension::HEIGHT);
-    const size_t idx_channel = get_data_layout_dimension_index(input.data_layout(), DataLayoutDimension::CHANNEL);
+    const size_t idx_width   = get_data_layout_dimension_index(input_data_layout, DataLayoutDimension::WIDTH);
+    const size_t idx_height  = get_data_layout_dimension_index(input_data_layout, DataLayoutDimension::HEIGHT);
+    const size_t idx_channel = get_data_layout_dimension_index(input_data_layout, DataLayoutDimension::CHANNEL);
 
     const unsigned int input_width         = input_shape[idx_width];
     const unsigned int input_height        = input_shape[idx_height];
@@ -735,6 +733,19 @@
     return output_shape;
 }
 
+/** Calculate the deep convolution shape output shape of a tensor
+ *
+ * @param[in] input     Input tensor info
+ * @param[in] weights   Weights tensor info
+ * @param[in] conv_info Contains padding and stride information
+ *
+ * @return the calculated shape
+ */
+inline TensorShape compute_deep_convolution_shape(const ITensorInfo &input, const ITensorInfo &weights, const PadStrideInfo &conv_info)
+{
+    return compute_deep_convolution_shape(input.tensor_shape(), input.data_layout(), weights.tensor_shape(), conv_info);
+}
+
 /** Calculate the min/max shape output shape of a tensor
  *
  * @param[in] input Input tensor info
diff --git a/src/core/CL/CLUtils.cpp b/src/core/CL/CLUtils.cpp
index 88b31c8..8dab8aa 100644
--- a/src/core/CL/CLUtils.cpp
+++ b/src/core/CL/CLUtils.cpp
@@ -96,9 +96,9 @@
                 return false;
             }
             // NOTE: Kernel limitation: currently only the following broadcasting types are supported:
-            //  1. Post op arg is scalar, broadcast in both X and Y
-            //  2. Post op arg is of shape: Y=1, X=N, broadcast only in Y
-            //  This means this case: Post op arg is of shape: Y=M, X=1, broadcast only in X, is NOT supported
+            //  1. Post op arg is scalar, broadcast in both first and second dims
+            //  2. Post op arg is of shape: second dim=1, first dim=N, broadcast only in second dim
+            //  This means this case: Post op arg is of shape: second dim=M, first dim=1, broadcast only in first dim, is NOT supported
             if(dst->dimension(0) > 1 && dst->dimension(1) > 1 && (*tensor)->dimension(0) == 1 && (*tensor)->dimension(1) > 1)
             {
                 return false;
diff --git a/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl b/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl
index bbe97b2..4665d61 100644
--- a/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl
+++ b/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl
@@ -133,7 +133,7 @@
                                                      IMAGE_DECLARATION(bias),
 #endif // defined(BETA)
                                                      IMAGE_DECLARATION(dst),
-                                                     // Post-Op arguments
+                                                     // Post Op arguments
                                                      IMAGE_DECLARATION(eltwise_operand),
                                                      uint lhs_stride_z,
                                                      uint rhs_stride_z,
diff --git a/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl b/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl
index 9e9a73c..32186c3 100644
--- a/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl
+++ b/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl
@@ -233,7 +233,7 @@
                                                                     IMAGE_DECLARATION(bias),
 #endif // defined(BETA)
                                                                     IMAGE_DECLARATION(dst),
-                                                                    // Post-Op arguments
+                                                                    // Post Op arguments
                                                                     IMAGE_DECLARATION(eltwise_operand),
                                                                     uint k,
                                                                     uint lhs_stride_z,
@@ -453,7 +453,7 @@
                                                                             IMAGE_DECLARATION(bias),
 #endif // defined(BETA)
                                                                             IMAGE_DECLARATION(dst),
-                                                                            // Post-Op arguments
+                                                                            // Post Op arguments
                                                                             IMAGE_DECLARATION(eltwise_operand),
                                                                             uint k,
                                                                             uint lhs_stride_z,
@@ -781,7 +781,7 @@
                                                                     IMAGE_DECLARATION(bias),
 #endif // defined(BETA)
                                                                     IMAGE_DECLARATION(dst),
-                                                                    // Post-Op arguments
+                                                                    // Post Op arguments
                                                                     IMAGE_DECLARATION(eltwise_operand),
                                                                     uint k,
                                                                     uint lhs_stride_z,
@@ -1110,7 +1110,7 @@
                                                                             IMAGE_DECLARATION(bias),
 #endif // defined(BETA)
                                                                             IMAGE_DECLARATION(dst),
-                                                                            // Post-Op arguments
+                                                                            // Post Op arguments
                                                                             IMAGE_DECLARATION(eltwise_operand),
                                                                             uint k,
                                                                             uint lhs_stride_z,
diff --git a/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped_only_rhs.cl b/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped_only_rhs.cl
index fe2d103..e96aba6 100644
--- a/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped_only_rhs.cl
+++ b/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped_only_rhs.cl
@@ -177,7 +177,7 @@
                                                                   IMAGE_DECLARATION(bias),
 #endif // defined(BETA)
                                                                   IMAGE_DECLARATION(dst),
-                                                                  // Post-Op arguments
+                                                                  // Post Op arguments
                                                                   IMAGE_DECLARATION(eltwise_operand),
                                                                   uint lhs_stride_z,
                                                                   uint rhs_stride_z,
@@ -437,7 +437,7 @@
                                                                           IMAGE_DECLARATION(bias),
 #endif // defined(BETA)
                                                                           IMAGE_DECLARATION(dst),
-                                                                          // Post-Op arguments
+                                                                          // Post Op arguments
                                                                           IMAGE_DECLARATION(eltwise_operand),
                                                                           uint lhs_stride_z,
                                                                           uint rhs_stride_z,
@@ -831,7 +831,7 @@
                                                                    IMAGE_DECLARATION(bias),
 #endif // defined(BETA)
                                                                    IMAGE_DECLARATION(dst),
-                                                                   // Post-Op arguments
+                                                                   // Post Op arguments
                                                                    IMAGE_DECLARATION(eltwise_operand),
                                                                    uint lhs_stride_z,
                                                                    uint rhs_stride_z,
@@ -1116,7 +1116,7 @@
                                                                            IMAGE_DECLARATION(bias),
 #endif // defined(BETA)
                                                                            IMAGE_DECLARATION(dst),
-                                                                           // Post-Op arguments
+                                                                           // Post Op arguments
                                                                            IMAGE_DECLARATION(eltwise_operand),
                                                                            uint lhs_stride_z,
                                                                            uint rhs_stride_z,
diff --git a/src/gpu/cl/operators/ClGemmConv2d.cpp b/src/gpu/cl/operators/ClGemmConv2d.cpp
index 7db5fa0..682477e 100644
--- a/src/gpu/cl/operators/ClGemmConv2d.cpp
+++ b/src/gpu/cl/operators/ClGemmConv2d.cpp
@@ -389,6 +389,9 @@
 
     ARM_COMPUTE_RETURN_ERROR_ON((weights->dimension(idx_channel) * conv2d_info.num_groups) != src->dimension(idx_channel));
     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(!skip_im2col
+                                    && conv2d_info.post_ops.size() > 0,
+                                    "ClGemmConv2d does not support post ops with col2im or im2col operation"); // Post ops must be performed after every other op
 
     // Validate biases
     if(biases != nullptr)
@@ -523,7 +526,6 @@
     // Validate Col2Im
     if(!skip_col2im)
     {
-        ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv2d_info.post_ops.size() > 0, "ClGemmConv2d does not support post ops with col2im operation"); // Post ops must be performed after every other op
         ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClCol2ImKernel::validate(gemm_output_to_use, dst, Size2D(conv_w, conv_h), conv2d_info.num_groups));
     }
 
diff --git a/tests/validation/CL/ConvolutionLayer.cpp b/tests/validation/CL/ConvolutionLayer.cpp
index ae2949c..ff28ac0 100644
--- a/tests/validation/CL/ConvolutionLayer.cpp
+++ b/tests/validation/CL/ConvolutionLayer.cpp
@@ -22,10 +22,12 @@
  * SOFTWARE.
  */
 #include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/runtime/CL/CLTensor.h"
 #include "arm_compute/runtime/CL/CLTensorAllocator.h"
 #include "arm_compute/runtime/CL/functions/CLConvolutionLayer.h"
 #include "arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h"
+#include "src/core/experimental/PostOp.h"
 #include "tests/CL/CLAccessor.h"
 #include "tests/PaddingCalculator.h"
 #include "tests/datasets/LargeConvolutionLayerDataset.h"
@@ -88,6 +90,29 @@
     ActivationLayerInfo(),
     ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 0.5f)
 });
+
+bool is_post_op_list_valid_in_gemmconv(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &output_shape, DataType data_type, DataLayout data_layout,
+                                       const PadStrideInfo &conv_info, const experimental::PostOpList<ITensorInfo *> &post_ops)
+{
+    const int idx_width   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+    const int idx_height  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+    const int idx_kernels = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);
+
+    const auto         dilation   = Size2D(1U, 1U);
+    const unsigned int num_groups = 1U;
+
+    TensorInfo input_info(input_shape, 1, data_type, data_layout);
+    TensorInfo weights_info(weights_shape, 1, data_type, data_layout);
+
+    TensorInfo output_info(output_shape, 1, data_type, data_layout);
+
+    WeightsInfo w_info(false, weights_info.dimension(idx_width), weights_info.dimension(idx_height), weights_info.dimension(idx_kernels));
+
+    const auto status = CLGEMMConvolutionLayer::validate(&input_info.clone()->set_is_resizable(true),
+                                                         &weights_info.clone()->set_is_resizable(true), nullptr, &output_info.clone()->set_is_resizable(true),
+                                                         conv_info, w_info, dilation, ActivationLayerInfo(), num_groups, post_ops);
+    return bool(status);
+}
 } // namespace
 
 TEST_SUITE(CL)
@@ -179,6 +204,72 @@
                                                                             enable_fast_math);
     ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS);
 }
+
+DATA_TEST_CASE(ValidatePostOpSupportInConvolutionMethod, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(
+                                          framework::dataset::make("InputInfo", { TensorInfo(TensorShape(2U, 17U, 31U), 1, DataType::F32, DataLayout::NHWC),            // Select GEMM
+                                                                                  TensorInfo(TensorShape(17U, 31U, 32U), 1, DataType::F32, DataLayout::NCHW),           // Select WINOGRAD
+                                                                                  TensorInfo(TensorShape(27U, 27U, 48U), 1, DataType::F32, DataLayout::NCHW),           // Select Direct
+                                                                                  TensorInfo(TensorShape(27U, 27U, 48U), 1, DataType::F32, DataLayout::NCHW),           // Select FFT
+                                          }),
+                                          framework::dataset::make("WeightsInfo", { TensorInfo(TensorShape(2U, 1U, 1U, 19U), 1, DataType::F32, DataLayout::NHWC),
+                                                                                    TensorInfo(TensorShape(5U, 5U, 32U, 19U), 1, DataType::F32, DataLayout::NCHW),
+                                                                                    TensorInfo(TensorShape(5U, 5U, 48U, 128U), 1, DataType::F32, DataLayout::NCHW),
+                                                                                    TensorInfo(TensorShape(11U, 11U, 48U, 24), 1, DataType::F32, DataLayout::NCHW),
+                                          })),
+                                          framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(19U, 17U, 31U), 1, DataType::F32, DataLayout::NHWC),
+                                                                                   TensorInfo(TensorShape(17U, 31U, 19U), 1, DataType::F32, DataLayout::NCHW),
+                                                                                   TensorInfo(TensorShape(27U, 27U, 128U), 1, DataType::F32, DataLayout::NCHW),
+                                                                                   TensorInfo(TensorShape(27U, 27U, 24U), 1, DataType::F32, DataLayout::NCHW),
+                                          })),
+                                          framework::dataset::make("ConvInfo", { PadStrideInfo(1U, 1U, 0U, 0U),
+                                                                                 PadStrideInfo(1U, 1U, 2U, 2U),
+                                                                                 PadStrideInfo(1U, 1U, 2U, 2U),
+                                                                                 PadStrideInfo(1U, 1U, 5U, 5U),
+                                          })),
+                                         framework::dataset::make("EnableFastMath", { false, true, false, false})),
+                                         framework::dataset::make("ExpectedMethod",{ ConvolutionMethod::GEMM,
+                                                                                     ConvolutionMethod::WINOGRAD,
+                                                                                     ConvolutionMethod::DIRECT,
+                                                                                     ConvolutionMethod::FFT,
+                                         })),
+                                         framework::dataset::make("PostOpSupported",{ true, false, false, false
+                                         })),
+                                         input_info, weights_info, output_info, conv_info, enable_fast_math, expected_method, post_op_supported)
+{
+    const int idx_width  = get_data_layout_dimension_index(input_info.data_layout(), DataLayoutDimension::WIDTH);
+    const int idx_height = get_data_layout_dimension_index(input_info.data_layout(), DataLayoutDimension::HEIGHT);
+    const int idx_kernels = get_data_layout_dimension_index(input_info.data_layout(), DataLayoutDimension::BATCHES);
+
+    const auto dilation = Size2D(1U, 1U);
+    const unsigned int num_groups = 1U;
+
+    WeightsInfo w_info(false, weights_info.dimension(idx_width), weights_info.dimension(idx_height), weights_info.dimension(idx_kernels));
+
+    experimental::PostOpList<ITensorInfo*> post_ops{};
+    post_ops.push_back_op<experimental::PostOpAct<ITensorInfo*>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+
+    ConvolutionMethod actual_method = CLConvolutionLayer::get_convolution_method(&input_info.clone()->set_is_resizable(true),
+                                                                            &weights_info.clone()->set_is_resizable(true),
+                                                                            &output_info.clone()->set_is_resizable(true), conv_info,
+                                                                            WeightsInfo(),
+                                                                            ActivationLayerInfo(),
+                                                                            GPUTarget::BIFROST,
+                                                                            dilation,
+                                                                            enable_fast_math);
+    ARM_COMPUTE_EXPECT(actual_method == expected_method, framework::LogLevel::ERRORS);
+    const auto is_valid = CLConvolutionLayer::validate(&input_info.clone()->set_is_resizable(true),
+                                                                            &weights_info.clone()->set_is_resizable(true),
+                                                                            nullptr,
+                                                                            &output_info.clone()->set_is_resizable(true),
+                                                                            conv_info,
+                                                                            w_info,
+                                                                            dilation,
+                                                                            ActivationLayerInfo(),
+                                                                            enable_fast_math,
+                                                                            num_groups,
+                                                                            post_ops);
+    ARM_COMPUTE_EXPECT( bool(is_valid) == post_op_supported, framework::LogLevel::ERRORS);
+}
 // clang-format on
 // *INDENT-ON*
 TEST_SUITE_END() // ConvolutionLayer
@@ -191,6 +282,159 @@
 template <typename T>
 using CLConvolutionValidationWithPaddingFixture = ConvolutionValidationWithPaddingFixture<CLTensor, CLAccessor, CLGEMMConvolutionLayer, T>;
 
+TEST_SUITE(ValidateFusedPostOpsConfigs)
+TEST_SUITE(Invalid)
+TEST_CASE(UnsupportedPostOpSequence, framework::DatasetMode::ALL)
+{
+    const auto data_type     = DataType::F32;
+    const auto data_layout   = DataLayout::NHWC;
+    const auto conv_info     = PadStrideInfo(1, 1, 0, 0);
+    const auto input_shape   = TensorShape(16U, 14U, 12U, 2U);
+    const auto weights_shape = TensorShape(16U, 1U, 1U, 24U);
+
+    const auto output_shape = misc::shape_calculator::compute_deep_convolution_shape(input_shape, data_layout, weights_shape, conv_info);
+
+    const TensorShape post_op_arg0_shape(output_shape);
+    TensorInfo        post_op_arg_info(post_op_arg0_shape, 1, data_type);
+    auto              post_op_arg1_info = post_op_arg_info.clone();
+
+    // Unsupported sequence of post ops
+    experimental::PostOpList<ITensorInfo *> post_ops{};
+    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(
+                                                                          &post_op_arg_info,
+                                                                          1,
+                                                                          ConvertPolicy::SATURATE);
+    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(
+                                                                          post_op_arg1_info.get(),
+                                                                          0,
+                                                                          ConvertPolicy::SATURATE);
+
+    ARM_COMPUTE_EXPECT(is_post_op_list_valid_in_gemmconv(input_shape, weights_shape, output_shape, data_type, data_layout, conv_info, post_ops) == false, framework::LogLevel::ERRORS);
+}
+TEST_CASE(OnlyNHWCIsSupported, framework::DatasetMode::ALL)
+{
+    const auto data_type     = DataType::F32;
+    const auto data_layout   = DataLayout::NCHW;
+    const auto conv_info     = PadStrideInfo(1, 1, 0, 0);
+    const auto input_shape   = TensorShape(14U, 12U, 16U, 2U);
+    const auto weights_shape = TensorShape(1U, 1U, 16U, 24U);
+
+    const auto output_shape = misc::shape_calculator::compute_deep_convolution_shape(input_shape, data_layout, weights_shape, conv_info);
+
+    const TensorShape post_op_arg0_shape(output_shape);
+    TensorInfo        post_op_arg_info(post_op_arg0_shape, 1, data_type);
+
+    experimental::PostOpList<ITensorInfo *> post_ops{};
+    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(
+                                                                          &post_op_arg_info,
+                                                                          1,
+                                                                          ConvertPolicy::SATURATE);
+
+    ARM_COMPUTE_EXPECT(is_post_op_list_valid_in_gemmconv(input_shape, weights_shape, output_shape, data_type, data_layout, conv_info, post_ops) == false, framework::LogLevel::ERRORS);
+}
+TEST_CASE(OnlyFloatingTypeIsSupported, framework::DatasetMode::ALL)
+{
+    const auto data_type     = DataType::QASYMM8;
+    const auto data_layout   = DataLayout::NHWC;
+    const auto conv_info     = PadStrideInfo(1, 1, 0, 0);
+    const auto input_shape   = TensorShape(16U, 14U, 12U, 2U);
+    const auto weights_shape = TensorShape(16U, 1U, 1U, 24U);
+
+    const auto output_shape = misc::shape_calculator::compute_deep_convolution_shape(input_shape, data_layout, weights_shape, conv_info);
+
+    const TensorShape post_op_arg0_shape(output_shape);
+    TensorInfo        post_op_arg_info(post_op_arg0_shape, 1, data_type);
+
+    experimental::PostOpList<ITensorInfo *> post_ops{};
+    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(
+                                                                          &post_op_arg_info,
+                                                                          1,
+                                                                          ConvertPolicy::SATURATE);
+
+    ARM_COMPUTE_EXPECT(is_post_op_list_valid_in_gemmconv(input_shape, weights_shape, output_shape, data_type, data_layout, conv_info, post_ops) == false, framework::LogLevel::ERRORS);
+}
+TEST_CASE(OnlyConv1x1Stride1IsSupported_UnsupportedKernelSize, framework::DatasetMode::ALL)
+{
+    const auto data_type     = DataType::F32;
+    const auto data_layout   = DataLayout::NHWC;
+    const auto conv_info     = PadStrideInfo(1, 1, 0, 0);
+    const auto input_shape   = TensorShape(16U, 14U, 12U, 2U);
+    const auto weights_shape = TensorShape(16U, 3U, 3U, 24U);
+
+    const auto output_shape = misc::shape_calculator::compute_deep_convolution_shape(input_shape, data_layout, weights_shape, conv_info);
+
+    const TensorShape post_op_arg0_shape(output_shape);
+    TensorInfo        post_op_arg_info(post_op_arg0_shape, 1, data_type);
+
+    experimental::PostOpList<ITensorInfo *> post_ops{};
+    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(
+                                                                          &post_op_arg_info,
+                                                                          1,
+                                                                          ConvertPolicy::SATURATE);
+
+    ARM_COMPUTE_EXPECT(is_post_op_list_valid_in_gemmconv(input_shape, weights_shape, output_shape, data_type, data_layout, conv_info, post_ops) == false, framework::LogLevel::ERRORS);
+}
+TEST_CASE(OnlyConv1x1Stride1IsSupported_UnsupportedStride, framework::DatasetMode::ALL)
+{
+    const auto data_type     = DataType::F32;
+    const auto data_layout   = DataLayout::NHWC;
+    const auto conv_info     = PadStrideInfo(3, 3, 0, 0);
+    const auto input_shape   = TensorShape(16U, 14U, 12U, 2U);
+    const auto weights_shape = TensorShape(16U, 1U, 1U, 24U);
+
+    const auto output_shape = misc::shape_calculator::compute_deep_convolution_shape(input_shape, data_layout, weights_shape, conv_info);
+
+    const TensorShape post_op_arg0_shape(output_shape);
+    TensorInfo        post_op_arg_info(post_op_arg0_shape, 1, data_type);
+
+    experimental::PostOpList<ITensorInfo *> post_ops{};
+    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(
+                                                                          &post_op_arg_info,
+                                                                          1,
+                                                                          ConvertPolicy::SATURATE);
+
+    ARM_COMPUTE_EXPECT(is_post_op_list_valid_in_gemmconv(input_shape, weights_shape, output_shape, data_type, data_layout, conv_info, post_ops) == false, framework::LogLevel::ERRORS);
+}
+TEST_SUITE_END() // Invalid
+TEST_SUITE(Valid)
+TEST_CASE(EmptyPostOpList, framework::DatasetMode::ALL)
+{
+    const auto data_type     = DataType::F32;
+    const auto data_layout   = DataLayout::NHWC;
+    const auto conv_info     = PadStrideInfo(1, 1, 0, 0);
+    const auto input_shape   = TensorShape(16U, 14U, 12U, 2U);
+    const auto weights_shape = TensorShape(16U, 1U, 1U, 24U);
+
+    const auto output_shape = misc::shape_calculator::compute_deep_convolution_shape(input_shape, data_layout, weights_shape, conv_info);
+
+    experimental::PostOpList<ITensorInfo *> post_ops{};
+
+    ARM_COMPUTE_EXPECT(is_post_op_list_valid_in_gemmconv(input_shape, weights_shape, output_shape, data_type, data_layout, conv_info, post_ops) == true, framework::LogLevel::ERRORS);
+}
+TEST_CASE(SupportedPostOps, framework::DatasetMode::ALL)
+{
+    const auto data_type     = DataType::F32;
+    const auto data_layout   = DataLayout::NHWC;
+    const auto conv_info     = PadStrideInfo(1, 1, 0, 0);
+    const auto input_shape   = TensorShape(16U, 14U, 12U, 2U);
+    const auto weights_shape = TensorShape(16U, 1U, 1U, 24U);
+
+    const auto output_shape = misc::shape_calculator::compute_deep_convolution_shape(input_shape, data_layout, weights_shape, conv_info);
+
+    TensorShape post_op_arg0_shape(output_shape);
+    post_op_arg0_shape[1] = 1; // Broadcast in "Y" (second) dimension
+    TensorInfo post_op_arg_info(post_op_arg0_shape, 1, data_type);
+
+    experimental::PostOpList<ITensorInfo *> post_ops{};
+    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(
+                                                                          &post_op_arg_info,
+                                                                          1,
+                                                                          ConvertPolicy::SATURATE);
+
+    ARM_COMPUTE_EXPECT(is_post_op_list_valid_in_gemmconv(input_shape, weights_shape, output_shape, data_type, data_layout, conv_info, post_ops) == true, framework::LogLevel::ERRORS);
+}
+TEST_SUITE_END() // Valid
+TEST_SUITE_END() // ValidateFusedPostOps
 TEST_SUITE(Float)
 TEST_SUITE(FP16)
 
diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h
index 785b41f..9858478 100644
--- a/utils/TypePrinter.h
+++ b/utils/TypePrinter.h
@@ -2180,6 +2180,12 @@
         case ConvolutionMethod::WINOGRAD:
             os << "WINOGRAD";
             break;
+        case ConvolutionMethod::FFT:
+            os << "FFT";
+            break;
+        case ConvolutionMethod::GEMM_CONV2D:
+            os << "GEMM_CONV2D";
+            break;
         default:
             ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
     }