blob: 6dfd1abf84884a11b987191b1b07a75a731f391f [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "SSDResultDecoder.hpp"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

12namespace od
13{
14
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010015DetectedObjects SSDResultDecoder::Decode(const common::InferenceResults<float>& networkResults,
16 const common::Size& outputFrameSize,
17 const common::Size& resizedFrameSize,
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +010018 const std::vector<std::string>& labels)
19{
20 // SSD network outputs 4 tensors: bounding boxes, labels, probabilities, number of detections.
21 if (networkResults.size() != 4)
22 {
23 throw std::runtime_error("Number of outputs from SSD model doesn't equal 4");
24 }
25
26 DetectedObjects detectedObjects;
27 const int numDetections = static_cast<int>(std::lround(networkResults[3][0]));
28
29 double longEdgeInput = std::max(resizedFrameSize.m_Width, resizedFrameSize.m_Height);
30 double longEdgeOutput = std::max(outputFrameSize.m_Width, outputFrameSize.m_Height);
31 const double resizeFactor = longEdgeOutput/longEdgeInput;
32
33 for (int i=0; i<numDetections; ++i)
34 {
35 if (networkResults[2][i] > m_objectThreshold)
36 {
37 DetectedObject detectedObject;
38 detectedObject.SetScore(networkResults[2][i]);
39 auto classId = std::lround(networkResults[1][i]);
40
41 if (classId < labels.size())
42 {
43 detectedObject.SetLabel(labels[classId]);
44 }
45 else
46 {
47 detectedObject.SetLabel(std::to_string(classId));
48 }
49 detectedObject.SetId(classId);
50
51 // Convert SSD bbox outputs (ratios of image size) to pixel values.
52 double topLeftY = networkResults[0][i*4 + 0] * resizedFrameSize.m_Height;
53 double topLeftX = networkResults[0][i*4 + 1] * resizedFrameSize.m_Width;
54 double botRightY = networkResults[0][i*4 + 2] * resizedFrameSize.m_Height;
55 double botRightX = networkResults[0][i*4 + 3] * resizedFrameSize.m_Width;
56
57 // Scale the coordinates to output frame size.
58 topLeftY *= resizeFactor;
59 topLeftX *= resizeFactor;
60 botRightY *= resizeFactor;
61 botRightX *= resizeFactor;
62
63 assert(botRightX > topLeftX);
64 assert(botRightY > topLeftY);
65
66 // Internal BoundingBox stores box top left x,y and width, height.
67 detectedObject.SetBoundingBox({static_cast<int>(std::round(topLeftX)),
68 static_cast<int>(std::round(topLeftY)),
69 static_cast<unsigned int>(botRightX - topLeftX),
70 static_cast<unsigned int>(botRightY - topLeftY)});
71
72 detectedObjects.emplace_back(detectedObject);
73 }
74 }
75 return detectedObjects;
76}
77
78SSDResultDecoder::SSDResultDecoder(float ObjectThreshold) : m_objectThreshold(ObjectThreshold) {}
79
80}// namespace od