blob: f71f72af3add08fc92a51597a75bd432f5b5e00f [file] [log] [blame]
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +00001//
Declan-ARM7c75e332024-03-12 16:40:25 +00002// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +00003// SPDX-License-Identifier: MIT
4//
5
6#include "DetectionPostProcessLayer.hpp"
7
8#include "LayerCloneBase.hpp"
9
10#include <armnn/TypesUtils.hpp>
Colm Donelan0c479742021-12-10 12:43:54 +000011#include <armnn/backends/TensorHandle.hpp>
12#include <armnn/backends/WorkloadData.hpp>
13#include <armnn/backends/WorkloadFactory.hpp>
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +000014
15namespace armnn
16{
17
18DetectionPostProcessLayer::DetectionPostProcessLayer(const DetectionPostProcessDescriptor& param, const char* name)
19 : LayerWithParameters(2, 4, LayerType::DetectionPostProcess, param, name)
20{
21}
22
Derek Lamberti94a88d22019-12-10 21:12:59 +000023std::unique_ptr<IWorkload> DetectionPostProcessLayer::CreateWorkload(const armnn::IWorkloadFactory& factory) const
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +000024{
25 DetectionPostProcessQueueDescriptor descriptor;
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +000026 descriptor.m_Anchors = m_Anchors.get();
Keith Davisdf04d232020-10-23 17:20:05 +010027 SetAdditionalInfo(descriptor);
28
Teresa Charlin611c7fb2022-01-07 09:47:29 +000029 return factory.CreateWorkload(LayerType::DetectionPostProcess, descriptor, PrepInfoAndDesc(descriptor));
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +000030}
31
32DetectionPostProcessLayer* DetectionPostProcessLayer::Clone(Graph& graph) const
33{
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +000034 auto layer = CloneBase<DetectionPostProcessLayer>(graph, m_Param, GetName());
Finn Williams4422cec2021-03-22 17:51:06 +000035 layer->m_Anchors = m_Anchors ? m_Anchors : nullptr;
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +000036 return std::move(layer);
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +000037}
38
Finn Williamsf24effa2020-07-03 10:12:03 +010039void DetectionPostProcessLayer::ValidateTensorShapesFromInputs()
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +000040{
41 VerifyLayerConnections(2, CHECK_LOCATION());
Narumol Prangnawarata0d56c72019-01-25 10:46:40 +000042
Finn Williams87d0bda2020-07-03 10:12:03 +010043 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
44
Finn Williamsf24effa2020-07-03 10:12:03 +010045 VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
Finn Williams87d0bda2020-07-03 10:12:03 +010046
Narumol Prangnawarata0d56c72019-01-25 10:46:40 +000047 // on this level constant data should not be released.
Declan-ARM7c75e332024-03-12 16:40:25 +000048 if (!m_Anchors)
49 {
50 throw armnn::LayerValidationException("DetectionPostProcessLayer: Anchors data should not be null.");
51 }
Narumol Prangnawarata0d56c72019-01-25 10:46:40 +000052
Declan-ARM7c75e332024-03-12 16:40:25 +000053 if (GetNumOutputSlots() != 4)
54 {
55 throw armnn::LayerValidationException("DetectionPostProcessLayer: The layer should return 4 outputs.");
56 }
Narumol Prangnawarata0d56c72019-01-25 10:46:40 +000057
Mike Kelly377fb212023-01-10 15:55:28 +000058 std::vector<TensorShape> inferredShapes = InferOutputShapes(
Mike Kellya9ac6ba2023-06-30 15:18:26 +010059 { GetInputSlot(0).GetTensorInfo().GetShape(),
60 GetInputSlot(1).GetTensorInfo().GetShape() });
Narumol Prangnawarata0d56c72019-01-25 10:46:40 +000061
Declan-ARM7c75e332024-03-12 16:40:25 +000062 if (inferredShapes.size() != 4)
63 {
64 throw armnn::LayerValidationException("inferredShapes has "
65 + std::to_string(inferredShapes.size()) +
66 " element(s) - should only have 4.");
67 }
68
69 if (std::any_of(inferredShapes.begin(), inferredShapes.end(), [] (auto&& inferredShape) {
70 return inferredShape.GetDimensionality() != Dimensionality::Specified;
71 }))
72 {
73 throw armnn::Exception("One of inferredShapes' dimensionalities is not specified.");
74 }
Narumol Prangnawarata0d56c72019-01-25 10:46:40 +000075
Mike Kelly377fb212023-01-10 15:55:28 +000076 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "DetectionPostProcessLayer");
Finn Williams87d0bda2020-07-03 10:12:03 +010077
78 ValidateAndCopyShape(GetOutputSlot(1).GetTensorInfo().GetShape(),
Mike Kelly377fb212023-01-10 15:55:28 +000079 inferredShapes[1],
Finn Williamsf24effa2020-07-03 10:12:03 +010080 m_ShapeInferenceMethod,
Finn Williams87d0bda2020-07-03 10:12:03 +010081 "DetectionPostProcessLayer", 1);
82
83 ValidateAndCopyShape(GetOutputSlot(2).GetTensorInfo().GetShape(),
Mike Kelly377fb212023-01-10 15:55:28 +000084 inferredShapes[2],
Finn Williamsf24effa2020-07-03 10:12:03 +010085 m_ShapeInferenceMethod,
Finn Williams87d0bda2020-07-03 10:12:03 +010086 "DetectionPostProcessLayer", 2);
87
88 ValidateAndCopyShape(GetOutputSlot(3).GetTensorInfo().GetShape(),
Mike Kelly377fb212023-01-10 15:55:28 +000089 inferredShapes[3],
Finn Williamsf24effa2020-07-03 10:12:03 +010090 m_ShapeInferenceMethod,
Finn Williams87d0bda2020-07-03 10:12:03 +010091 "DetectionPostProcessLayer", 3);
Narumol Prangnawarata0d56c72019-01-25 10:46:40 +000092}
93
Mike Kelly377fb212023-01-10 15:55:28 +000094std::vector<TensorShape> DetectionPostProcessLayer::InferOutputShapes(const std::vector<TensorShape>&) const
95{
96 unsigned int detectedBoxes = m_Param.m_MaxDetections * m_Param.m_MaxClassesPerDetection;
97
98 std::vector<TensorShape> results;
99 results.push_back({ 1, detectedBoxes, 4 });
100 results.push_back({ 1, detectedBoxes });
101 results.push_back({ 1, detectedBoxes });
102 results.push_back({ 1 });
103 return results;
104}
105
Matthew Benthamaeec3ce2023-02-23 13:03:46 +0000106Layer::ImmutableConstantTensors DetectionPostProcessLayer::GetConstantTensorsByRef() const
Narumol Prangnawarata0d56c72019-01-25 10:46:40 +0000107{
Nikhil Raj2e241752022-02-01 16:42:15 +0000108 // For API stability DO NOT ALTER order and add new members to the end of vector
Narumol Prangnawarata0d56c72019-01-25 10:46:40 +0000109 return { m_Anchors };
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000110}
111
Finn Williamsb454c5c2021-02-09 15:56:23 +0000112void DetectionPostProcessLayer::ExecuteStrategy(IStrategy& strategy) const
113{
Francis Murtagh4af56162021-04-20 16:37:55 +0100114 ManagedConstTensorHandle managedAnchors(m_Anchors);
115 std::vector<armnn::ConstTensor> constTensors { {managedAnchors.GetTensorInfo(), managedAnchors.Map()} };
Finn Williamsb454c5c2021-02-09 15:56:23 +0000116 strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());
117}
118
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +0000119} // namespace armnn