# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

"""
Contains functions specific to decoding and processing inference results for YOLO V3 Tiny models.
"""

import cv2
import numpy as np


def iou(box1: list, box2: list):
    """
    Calculates the intersection-over-union (IoU) value for two bounding boxes.

    Args:
        box1: Array of positions for first bounding box
              in the form [x_min, y_min, x_max, y_max].
        box2: Array of positions for second bounding box.

    Returns:
        Calculated intersection-over-union (IoU) value for the two bounding boxes.
    """
    area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

    # Boxes with non-positive area cannot overlap
    if area_box1 <= 0 or area_box2 <= 0:
        iou_value = 0
    else:
        y_min_intersection = max(box1[1], box2[1])
        x_min_intersection = max(box1[0], box2[0])
        y_max_intersection = min(box1[3], box2[3])
        x_max_intersection = min(box1[2], box2[2])

        area_intersection = max(0, y_max_intersection - y_min_intersection) *\
                            max(0, x_max_intersection - x_min_intersection)
        area_union = area_box1 + area_box2 - area_intersection

        try:
            iou_value = area_intersection / area_union
        except ZeroDivisionError:
            iou_value = 0

    return iou_value
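
# A quick sanity check of iou() under the [x_min, y_min, x_max, y_max] box
# format documented above (hypothetical values): two 2x2 boxes whose
# intersection is 1 * 2 = 2 and union is 4 + 4 - 2 = 6, giving IoU = 1/3.
#
#   iou([0, 0, 2, 2], [1, 0, 3, 2])  # -> 0.333...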


def yolo_processing(output: np.ndarray, confidence_threshold=0.40, iou_threshold=0.40):
    """
    Performs non-maximum suppression (NMS) on the input detections. Any detection
    whose IoU with an already-selected, higher-confidence detection exceeds the
    given threshold is suppressed.

    Args:
        output: Vector of outputs from network.
        confidence_threshold: Selects only strong detections above this value.
        iou_threshold: Filters out boxes whose IoU with a kept box is above this value.

    Returns:
        A list of detected objects in the form [class, [box positions], confidence]
    """
    if len(output) != 1:
        raise RuntimeError('Number of outputs from YOLO model does not equal 1')

    # Find the array indices of detections with confidence value above threshold.
    # Each detection is laid out as [box positions, confidence, class scores...],
    # so the confidence score sits at index 4.
    confidence_det = output[0][:, :, 4][0]
    detections = list(np.where(confidence_det > confidence_threshold)[0])
    all_det, nms_det = [], []

    # Create list of all detections above confidence threshold
    for d in detections:
        box_positions = list(output[0][:, d, :4][0])
        confidence_score = output[0][:, d, 4][0]
        class_idx = np.argmax(output[0][:, d, 5:])
        all_det.append((class_idx, box_positions, confidence_score))

    # Suppress detections with IOU value above threshold: repeatedly keep the
    # most confident remaining detection and drop any detection that overlaps
    # it too strongly
    while all_det:
        element = int(np.argmax([all_det[i][2] for i in range(len(all_det))]))
        nms_det.append(all_det.pop(element))
        all_det = [det for det in all_det if iou(det[1], nms_det[-1][1]) <= iou_threshold]
    return nms_det
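
# Usage sketch (hypothetical shapes and values): a single YOLO V3 Tiny output
# tensor of shape (1, num_boxes, 85) -- 4 box coordinates, 1 confidence score
# and, for an 80-class model such as one trained on COCO, 80 class scores per
# box -- is passed in wrapped in a one-element list:
#
#   output = [np.random.rand(1, 2535, 85).astype(np.float32)]
#   for class_idx, box, score in yolo_processing(output):
#       print(class_idx, box, score)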


def yolo_resize_factor(video: cv2.VideoCapture, input_binding_info: tuple):
    """
    Gets a multiplier to scale the bounding box positions to
    their correct position in the frame.

    Args:
        video: Video capture object, contains information about data source.
        input_binding_info: Contains shape of model input layer.

    Returns:
        Resizing factor to scale box coordinates to output frame size.
    """
    frame_height = video.get(cv2.CAP_PROP_FRAME_HEIGHT)
    frame_width = video.get(cv2.CAP_PROP_FRAME_WIDTH)
    # Model input height and width are taken from dims 1 and 2 of the input
    # tensor shape, i.e. an NHWC layout is assumed
    model_height, model_width = list(input_binding_info[1].GetShape())[1:3]
    return max(frame_height, frame_width) / max(model_height, model_width)
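
# Usage sketch (hypothetical names): rescale the box positions returned by
# yolo_processing() back to frame coordinates before drawing. 'sample.mp4' and
# input_binding_info are placeholders for the caller's video source and the
# Arm NN input binding information for the model's input layer.
#
#   video = cv2.VideoCapture('sample.mp4')
#   resize_factor = yolo_resize_factor(video, input_binding_info)
#   x_min, y_min, x_max, y_max = (int(pos * resize_factor) for pos in box)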
|