COMPMID-1071: (3RDPARTY_UPDATE) Add depth multiplier support to DepthwiseConv 3x3 NHWC
Change-Id: I316ff40dda379d4b84fac5d63f0c56efbacbc2b4
Reviewed-on: https://review.mlplatform.org/371
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h
index 78a3f20..d070331 100644
--- a/arm_compute/graph/frontend/Layers.h
+++ b/arm_compute/graph/frontend/Layers.h
@@ -414,24 +414,27 @@
public:
/** Construct a depthwise convolution layer.
*
- * @param[in] conv_width Convolution width.
- * @param[in] conv_height Convolution height.
- * @param[in] weights Accessor to get kernel weights from.
- * @param[in] bias Accessor to get kernel bias from.
- * @param[in] conv_info Padding and stride information.
- * @param[in] quant_info (Optional) Quantization info used for weights
+ * @param[in] conv_width Convolution width.
+ * @param[in] conv_height Convolution height.
+ * @param[in] weights Accessor to get kernel weights from.
+ * @param[in] bias Accessor to get kernel bias from.
+ * @param[in] conv_info Padding and stride information.
+ * @param[in] depth_multiplier (Optional) Multiplier applied to the input's depth to obtain the output's depth. Defaults to 1.
+ * @param[in] quant_info (Optional) Quantization info used for weights.
*/
DepthwiseConvolutionLayer(unsigned int conv_width,
unsigned int conv_height,
ITensorAccessorUPtr weights,
ITensorAccessorUPtr bias,
PadStrideInfo conv_info,
- const QuantizationInfo quant_info = QuantizationInfo())
+ int depth_multiplier = 1,
+ const QuantizationInfo quant_info = QuantizationInfo())
: _conv_width(conv_width),
_conv_height(conv_height),
_conv_info(std::move(conv_info)),
_weights(std::move(weights)),
_bias(std::move(bias)),
+ _depth_multiplier(depth_multiplier),
_quant_info(std::move(quant_info))
{
}
@@ -441,7 +444,7 @@
NodeIdxPair input = { s.tail_node(), 0 };
NodeParams common_params = { name(), s.hints().target_hint };
return GraphBuilder::add_depthwise_convolution_node(s.graph(), common_params,
- input, Size2D(_conv_width, _conv_height), _conv_info,
+ input, Size2D(_conv_width, _conv_height), _conv_info, _depth_multiplier,
s.hints().depthwise_convolution_method_hint,
std::move(_weights), std::move(_bias), std::move(_quant_info));
}
@@ -452,6 +455,7 @@
const PadStrideInfo _conv_info;
ITensorAccessorUPtr _weights;
ITensorAccessorUPtr _bias;
+ int _depth_multiplier;
const QuantizationInfo _quant_info;
};