COMPMID-3221: Add DeconvolutionLayerDescriptor
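A new descriptor struct, DeconvolutionLayerDescriptor, is added to hold the
constructor arguments of DeconvolutionLayerNode (the PadStrideInfo and the
output QuantizationInfo), so that the node can be extended more easily.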
Change-Id: I935277e8073a8295de7b0059b946cb637085f1ff
Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2883
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp
index e429817..218e6ce 100644
--- a/src/graph/GraphBuilder.cpp
+++ b/src/graph/GraphBuilder.cpp
@@ -306,7 +306,7 @@
}
// Create convolution node and connect
- NodeID deconv_nid = g.add_node<DeconvolutionLayerNode>(deconv_info);
+ NodeID deconv_nid = g.add_node<DeconvolutionLayerNode>(descriptors::DeconvolutionLayerDescriptor{ deconv_info });
g.add_connection(input.node_id, input.index, deconv_nid, 0);
g.add_connection(w_nid, 0, deconv_nid, 1);
if(has_bias)
diff --git a/src/graph/nodes/DeconvolutionLayerNode.cpp b/src/graph/nodes/DeconvolutionLayerNode.cpp
index a2e4e2b..2daeaac 100644
--- a/src/graph/nodes/DeconvolutionLayerNode.cpp
+++ b/src/graph/nodes/DeconvolutionLayerNode.cpp
@@ -32,8 +32,8 @@
{
namespace graph
{
-DeconvolutionLayerNode::DeconvolutionLayerNode(PadStrideInfo info, QuantizationInfo out_quant_info)
- : _info(std::move(info)), _out_quant_info(std::move(out_quant_info))
+DeconvolutionLayerNode::DeconvolutionLayerNode(const descriptors::DeconvolutionLayerDescriptor &descriptor)
+ : descriptor(std::move(descriptor))
{
_input_edges.resize(3, EmptyEdgeID);
_outputs.resize(1, NullTensorID);
@@ -41,7 +41,7 @@
PadStrideInfo DeconvolutionLayerNode::deconvolution_info() const
{
- return _info;
+ return descriptor.info;
}
TensorDescriptor DeconvolutionLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor,
@@ -87,11 +87,11 @@
ARM_COMPUTE_ERROR_ON(src == nullptr || weights == nullptr);
- TensorDescriptor output_info = compute_output_descriptor(src->desc(), weights->desc(), _info);
+ TensorDescriptor output_info = compute_output_descriptor(src->desc(), weights->desc(), descriptor.info);
- if(!_out_quant_info.empty())
+ if(!descriptor.out_quant_info.empty())
{
- output_info.set_quantization_info(_out_quant_info);
+ output_info.set_quantization_info(descriptor.out_quant_info);
}
return output_info;