blob: 2c4a76d35a61f0a0c971186d263a5debd49bc0c9 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +01006#include "ObjectDetectionPipeline.hpp"
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +01007#include "ImageUtils.hpp"
8
9namespace od
10{
11
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010012ObjDetectionPipeline::ObjDetectionPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<float>> executor,
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +010013 std::unique_ptr<IDetectionResultDecoder> decoder) :
Eanna O Cathain2f0ddb62022-03-03 15:58:10 +000014 m_executor(std::move(executor)),
15 m_decoder(std::move(decoder)){}
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +010016
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010017void od::ObjDetectionPipeline::Inference(const cv::Mat& processed, common::InferenceResults<float>& result)
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +010018{
19 m_executor->Run(processed.data, processed.total() * processed.elemSize(), result);
20}
21
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010022void ObjDetectionPipeline::PostProcessing(common::InferenceResults<float>& inferenceResult,
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +010023 const std::function<void (DetectedObjects)>& callback)
24{
25 DetectedObjects detections = m_decoder->Decode(inferenceResult, m_inputImageSize,
26 m_executor->GetImageAspectRatio(), {});
27 if (callback)
28 {
29 callback(detections);
30 }
31}
32
33void ObjDetectionPipeline::PreProcessing(const cv::Mat& frame, cv::Mat& processed)
34{
35 m_inputImageSize.m_Height = frame.rows;
36 m_inputImageSize.m_Width = frame.cols;
37 ResizeWithPad(frame, processed, m_processedFrame, m_executor->GetImageAspectRatio());
38}
39
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010040MobileNetSSDv1::MobileNetSSDv1(std::unique_ptr<common::ArmnnNetworkExecutor<float>> executor,
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +010041 float objectThreshold) :
Eanna O Cathain2f0ddb62022-03-03 15:58:10 +000042 ObjDetectionPipeline(std::move(executor),
43 std::make_unique<SSDResultDecoder>(objectThreshold))
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +010044{}
45
46void MobileNetSSDv1::PreProcessing(const cv::Mat& frame, cv::Mat& processed)
47{
48 ObjDetectionPipeline::PreProcessing(frame, processed);
49 if (m_executor->GetInputDataType() == armnn::DataType::Float32)
50 {
51 // [0, 255] => [-1.0, 1.0]
52 processed.convertTo(processed, CV_32FC3, 1 / 127.5, -1);
53 }
54}
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010055YoloV3Tiny::YoloV3Tiny(std::unique_ptr<common::ArmnnNetworkExecutor<float>> executor,
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +010056 float NMSThreshold, float ClsThreshold, float ObjectThreshold) :
Eanna O Cathain2f0ddb62022-03-03 15:58:10 +000057 ObjDetectionPipeline(std::move(executor),
58 std::move(std::make_unique<YoloResultDecoder>(NMSThreshold,
59 ClsThreshold,
60 ObjectThreshold)))
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +010061{}
62
63void YoloV3Tiny::PreProcessing(const cv::Mat& frame, cv::Mat& processed)
64{
65 ObjDetectionPipeline::PreProcessing(frame, processed);
66 if (m_executor->GetInputDataType() == armnn::DataType::Float32)
67 {
68 processed.convertTo(processed, CV_32FC3);
69 }
70}
71
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010072IPipelinePtr CreatePipeline(common::PipelineOptions& config)
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +010073{
Eanna O Cathain2f0ddb62022-03-03 15:58:10 +000074 auto executor = std::make_unique<common::ArmnnNetworkExecutor<float>>(config.m_ModelFilePath,
75 config.m_backends,
76 config.m_ProfilingEnabled);
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +010077 if (config.m_ModelName == "SSD_MOBILE")
78 {
Eanna O Cathain2f0ddb62022-03-03 15:58:10 +000079 float detectionThreshold = 0.5;
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +010080
81 return std::make_unique<od::MobileNetSSDv1>(std::move(executor),
82 detectionThreshold
83 );
84 }
85 else if (config.m_ModelName == "YOLO_V3_TINY")
86 {
87 float NMSThreshold = 0.6f;
88 float ClsThreshold = 0.6f;
89 float ObjectThreshold = 0.6f;
90 return std::make_unique<od::YoloV3Tiny>(std::move(executor),
91 NMSThreshold,
92 ClsThreshold,
93 ObjectThreshold
94 );
95 }
96 else
97 {
98 throw std::invalid_argument("Unknown Model name: " + config.m_ModelName + " supplied by user.");
99 }
100
101}
Eanna O Cathain2f0ddb62022-03-03 15:58:10 +0000102}// namespace od