Jakub Sujak | 433a595 | 2020-06-17 15:35:03 +0100 | [diff] [blame] | 1 | # Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
|
| 2 | # SPDX-License-Identifier: MIT
|
| 3 |
|
| 4 | """
|
| 5 | Contains functions specific to decoding and processing inference results for SSD Mobilenet V1 models.
|
| 6 | """
|
| 7 |
|
| 8 | import cv2
|
| 9 | import numpy as np
|
| 10 |
|
| 11 |
|
def ssd_processing(output: np.ndarray, confidence_threshold=0.60):
    """
    Gets class, bounding box positions and confidence from the four outputs of the SSD model.

    The input arrays are never modified: every returned box is a fresh copy, so
    calling this function repeatedly on the same output gives the same result.

    Args:
        output: Vector of outputs from network. Must hold exactly four entries,
            unpacked in the order: box positions, class indices, confidences,
            number of detections. Only the first element of each entry is used.
        confidence_threshold: Selects only strong detections above this value.

    Returns:
        A list of detected objects as tuples (class, box positions, confidence),
        with the box positions reordered to [x_min, y_min, x_max, y_max].

    Raises:
        RuntimeError: If `output` does not contain exactly four entries.
    """
    if len(output) != 4:
        raise RuntimeError('Number of outputs from SSD model does not equal 4')

    position, classification, confidence, num_detections = [index[0] for index in output]

    detections = []
    for i in range(int(num_detections)):
        if confidence[i] > confidence_threshold:
            # Reorder positions in format [x_min, y_min, x_max, y_max] by swapping
            # each coordinate pair. Integer-array indexing returns a copy, whereas
            # slicing and swapping in place would write through to the caller's
            # output buffer and flip the coordinates back on a second call.
            box = position[i, [1, 0, 3, 2]]
            detections.append((classification[i], box, confidence[i]))
    return detections
|
| 38 |
|
| 39 |
|
def ssd_resize_factor(video: cv2.VideoCapture):
    """
    Computes the multiplier used to scale bounding box positions to
    their correct position in the frame.

    The factor is the larger of the frame height and frame width reported
    by the capture object.

    Args:
        video: Video capture object, contains information about data source.

    Returns:
        Resizing factor to scale box coordinates to output frame size.
    """
    # Query height first, then width, and keep whichever is larger.
    frame_dimensions = (
        video.get(cv2.CAP_PROP_FRAME_HEIGHT),
        video.get(cv2.CAP_PROP_FRAME_WIDTH),
    )
    return max(frame_dimensions)
|