COMPMID-1849: Implement CPPDetectionPostProcessLayer
* Add DetectionPostProcessLayer
* Add DetectionPostProcessLayer to the graph
Change-Id: I7e56f6cffc26f112d26dfe74853085bb8ec7d849
Signed-off-by: Isabella Gottardi <isabella.gottardi@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1639
Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp
index 54bd066..228f2d2 100644
--- a/src/graph/GraphBuilder.cpp
+++ b/src/graph/GraphBuilder.cpp
@@ -393,6 +393,36 @@
return detect_nid;
}
+// Adds a DetectionPostProcessLayer node to graph g and returns its node ID.
+//
+// Input slots of the new node:
+//   0 - input_box_encoding (output `index` of its producer node)
+//   1 - input_class_prediction (output `index` of its producer node)
+//   2 - a constant node named "Anchors", created with anchors_accessor
+// The anchors tensor descriptor is a copy of the box-encoding tensor descriptor.
+// Its quantization info is replaced by anchor_quant_info when that is non-empty.
+NodeID GraphBuilder::add_detection_post_process_node(Graph &g, NodeParams params, NodeIdxPair input_box_encoding, NodeIdxPair input_class_prediction, const DetectionPostProcessLayerInfo &detect_info,
+                                                     ITensorAccessorUPtr anchors_accessor, const QuantizationInfo &anchor_quant_info)
+{
+    check_nodeidx_pair(input_box_encoding, g);
+    check_nodeidx_pair(input_class_prediction, g);
+
+    // Use the descriptor of the exact output that is connected to slot 0 below.
+    // The producer may have several outputs, so hard-coding output 0 could pick
+    // a tensor other than the box encodings.
+    const TensorDescriptor input_box_encoding_tensor_desc = get_tensor_descriptor(g, g.node(input_box_encoding.node_id)->outputs()[input_box_encoding.index]);
+
+    // Anchors reuse the box-encoding descriptor; only the quantization info may be overridden
+    TensorDescriptor anchor_desc = input_box_encoding_tensor_desc;
+    if(!anchor_quant_info.empty())
+    {
+        anchor_desc.quant_info = anchor_quant_info;
+    }
+
+    // Create the anchors constant node
+    auto anchors_nid = add_const_node_with_name(g, params, "Anchors", anchor_desc, std::move(anchors_accessor));
+
+    // Create the detection post-process node and connect its three inputs
+    NodeID detect_nid = g.add_node<DetectionPostProcessLayerNode>(detect_info);
+    g.add_connection(input_box_encoding.node_id, input_box_encoding.index, detect_nid, 0);
+    g.add_connection(input_class_prediction.node_id, input_class_prediction.index, detect_nid, 1);
+    g.add_connection(anchors_nid, 0, detect_nid, 2);
+
+    set_node_params(g, detect_nid, params);
+
+    return detect_nid;
+}
+
NodeID GraphBuilder::add_dummy_node(Graph &g, NodeParams params, NodeIdxPair input, TensorShape shape)
{
return create_simple_single_input_output_node<DummyNode>(g, params, input, shape);