blob: f177802f8a68a9dadc58857bc3da8b037995179c [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#include "YoloResultDecoder.hpp"

#include "NonMaxSuppression.hpp"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
12
13namespace od
14{
15
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010016DetectedObjects YoloResultDecoder::Decode(const common::InferenceResults<float>& networkResults,
17 const common::Size& outputFrameSize,
18 const common::Size& resizedFrameSize,
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +010019 const std::vector<std::string>& labels)
20{
21
22 // Yolo v3 network outputs 1 tensor
23 if (networkResults.size() != 1)
24 {
25 throw std::runtime_error("Number of outputs from Yolo model doesn't equal 1");
26 }
27 auto element_step = m_boxElements + m_confidenceElements + m_numClasses;
28
29 float longEdgeInput = std::max(resizedFrameSize.m_Width, resizedFrameSize.m_Height);
30 float longEdgeOutput = std::max(outputFrameSize.m_Width, outputFrameSize.m_Height);
31 const float resizeFactor = longEdgeOutput/longEdgeInput;
32
33 DetectedObjects detectedObjects;
34 DetectedObjects resultsAfterNMS;
35
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010036 for (const common::InferenceResult<float>& result : networkResults)
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +010037 {
38 for (unsigned int i = 0; i < m_numBoxes; ++i)
39 {
40 const float* cur_box = &result[i * element_step];
41 // Objectness score
42 if (cur_box[4] > m_objectThreshold)
43 {
44 for (unsigned int classIndex = 0; classIndex < m_numClasses; ++classIndex)
45 {
46 const float class_prob = cur_box[4] * cur_box[5 + classIndex];
47
48 // class confidence
49
50 if (class_prob > m_ClsThreshold)
51 {
52 DetectedObject detectedObject;
53
54 detectedObject.SetScore(class_prob);
55
56 float topLeftX = cur_box[0] * resizeFactor;
57 float topLeftY = cur_box[1] * resizeFactor;
58 float botRightX = cur_box[2] * resizeFactor;
59 float botRightY = cur_box[3] * resizeFactor;
60
61 assert(botRightX > topLeftX);
62 assert(botRightY > topLeftY);
63
64 detectedObject.SetBoundingBox({static_cast<int>(topLeftX),
65 static_cast<int>(topLeftY),
66 static_cast<unsigned int>(botRightX-topLeftX),
67 static_cast<unsigned int>(botRightY-topLeftY)});
68 if(labels.size() > classIndex)
69 {
70 detectedObject.SetLabel(labels.at(classIndex));
71 }
72 else
73 {
74 detectedObject.SetLabel(std::to_string(classIndex));
75 }
76 detectedObject.SetId(classIndex);
77 detectedObjects.emplace_back(detectedObject);
78 }
79 }
80 }
81 }
82
83 std::vector<int> keepIndiciesAfterNMS = od::NonMaxSuppression(detectedObjects, m_NmsThreshold);
84
85 for (const int ind: keepIndiciesAfterNMS)
86 {
87 resultsAfterNMS.emplace_back(detectedObjects[ind]);
88 }
89 }
90
91 return resultsAfterNMS;
92}
93
94YoloResultDecoder::YoloResultDecoder(float NMSThreshold, float ClsThreshold, float ObjectThreshold)
95 : m_NmsThreshold(NMSThreshold), m_ClsThreshold(ClsThreshold), m_objectThreshold(ObjectThreshold) {}
96
97}// namespace od
98
99
100