blob: 7f05882fc41266015a1ef97987e461237aa0c640 [file] [log] [blame]
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "NetworkPipeline.hpp"
7#include "ImageUtils.hpp"
8
9namespace od
10{
11
12ObjDetectionPipeline::ObjDetectionPipeline(std::unique_ptr<ArmnnNetworkExecutor> executor,
13 std::unique_ptr<IDetectionResultDecoder> decoder) :
14 m_executor(std::move(executor)),
15 m_decoder(std::move(decoder)){}
16
17void od::ObjDetectionPipeline::Inference(const cv::Mat& processed, InferenceResults& result)
18{
19 m_executor->Run(processed.data, processed.total() * processed.elemSize(), result);
20}
21
22void ObjDetectionPipeline::PostProcessing(InferenceResults& inferenceResult,
23 const std::function<void (DetectedObjects)>& callback)
24{
25 DetectedObjects detections = m_decoder->Decode(inferenceResult, m_inputImageSize,
26 m_executor->GetImageAspectRatio(), {});
27 if (callback)
28 {
29 callback(detections);
30 }
31}
32
33void ObjDetectionPipeline::PreProcessing(const cv::Mat& frame, cv::Mat& processed)
34{
35 m_inputImageSize.m_Height = frame.rows;
36 m_inputImageSize.m_Width = frame.cols;
37 ResizeWithPad(frame, processed, m_processedFrame, m_executor->GetImageAspectRatio());
38}
39
40MobileNetSSDv1::MobileNetSSDv1(std::unique_ptr<ArmnnNetworkExecutor> executor,
41 float objectThreshold) :
42 ObjDetectionPipeline(std::move(executor),
43 std::make_unique<SSDResultDecoder>(objectThreshold))
44{}
45
46void MobileNetSSDv1::PreProcessing(const cv::Mat& frame, cv::Mat& processed)
47{
48 ObjDetectionPipeline::PreProcessing(frame, processed);
49 if (m_executor->GetInputDataType() == armnn::DataType::Float32)
50 {
51 // [0, 255] => [-1.0, 1.0]
52 processed.convertTo(processed, CV_32FC3, 1 / 127.5, -1);
53 }
54}
55
56YoloV3Tiny::YoloV3Tiny(std::unique_ptr<ArmnnNetworkExecutor> executor,
57 float NMSThreshold, float ClsThreshold, float ObjectThreshold) :
58 ObjDetectionPipeline(std::move(executor),
59 std::move(std::make_unique<YoloResultDecoder>(NMSThreshold,
60 ClsThreshold,
61 ObjectThreshold)))
62{}
63
64void YoloV3Tiny::PreProcessing(const cv::Mat& frame, cv::Mat& processed)
65{
66 ObjDetectionPipeline::PreProcessing(frame, processed);
67 if (m_executor->GetInputDataType() == armnn::DataType::Float32)
68 {
69 processed.convertTo(processed, CV_32FC3);
70 }
71}
72
73IPipelinePtr CreatePipeline(od::ODPipelineOptions& config)
74{
75 auto executor = std::make_unique<od::ArmnnNetworkExecutor>(config.m_ModelFilePath, config.m_backends);
76
77 if (config.m_ModelName == "SSD_MOBILE")
78 {
79 float detectionThreshold = 0.6;
80
81 return std::make_unique<od::MobileNetSSDv1>(std::move(executor),
82 detectionThreshold
83 );
84 }
85 else if (config.m_ModelName == "YOLO_V3_TINY")
86 {
87 float NMSThreshold = 0.6f;
88 float ClsThreshold = 0.6f;
89 float ObjectThreshold = 0.6f;
90 return std::make_unique<od::YoloV3Tiny>(std::move(executor),
91 NMSThreshold,
92 ClsThreshold,
93 ObjectThreshold
94 );
95 }
96 else
97 {
98 throw std::invalid_argument("Unknown Model name: " + config.m_ModelName + " supplied by user.");
99 }
100
101}
102}// namespace od