COMPMID-1505: Add native grouping support at graph level
Change-Id: Iedc91b0aee743b59af5140c8acb8124548da3163
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/144362
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-by: Michele DiGiorgio <michele.digiorgio@arm.com>
diff --git a/arm_compute/graph/nodes/ConvolutionLayerNode.h b/arm_compute/graph/nodes/ConvolutionLayerNode.h
index 4299be6..0698ac1 100644
--- a/arm_compute/graph/nodes/ConvolutionLayerNode.h
+++ b/arm_compute/graph/nodes/ConvolutionLayerNode.h
@@ -37,11 +37,13 @@
/** Constructor
*
* @param[in] info Convolution layer attributes
+ * @param[in] num_groups (Optional) Number of groups the convolution is split into (Defaults to 1, i.e. no grouping)
* @param[in] method (Optional) Convolution method to use
* @param[in] fast_math_hint (Optional) Fast math hint
* @param[in] out_quant_info (Optional) Output quantization info
*/
ConvolutionLayerNode(PadStrideInfo info,
+ unsigned int num_groups = 1,
ConvolutionMethod method = ConvolutionMethod::Default,
FastMathHint fast_math_hint = FastMathHint::Disabled,
QuantizationInfo out_quant_info = QuantizationInfo());
@@ -73,6 +75,11 @@
* @return Convolution information
*/
PadStrideInfo convolution_info() const;
+ /** Number of convolution groups accessor
+ *
+ * @return Number of groups the convolution is split into (1 means no grouping)
+ */
+ unsigned int num_groups() const;
/** Computes convolution output descriptor
*
* @param[in] input_descriptor Input descriptor
@@ -93,6 +100,7 @@
private:
PadStrideInfo _info;
+ unsigned int _num_groups;
ConvolutionMethod _method;
FastMathHint _fast_math_hint;
QuantizationInfo _out_quant_info;