blob: e76ed7b2f4586a002d779fbd7e351a8ee8d97807 [file] [log] [blame]
Jakub Sujak433a5952020-06-17 15:35:03 +01001# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
2# SPDX-License-Identifier: MIT
3
4"""
5Contains functions specific to decoding and processing inference results for YOLO V3 Tiny models.
6"""
7
8import cv2
9import numpy as np
10
11
def iou(box1: list, box2: list):
    """
    Return the intersection-over-union (IoU) of two bounding boxes.

    Args:
        box1: Coordinates of the first box, laid out as [x_min, y_min, x_max, y_max].
        box2: Coordinates of the second box, using the same layout as box1.

    Returns:
        The overlap area divided by the union area. The integer 0 is returned
        when either box has a non-positive area.
    """
    first_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    second_area = (box2[2] - box2[0]) * (box2[3] - box2[1])

    # A degenerate box (zero or negative area) is treated as having no overlap.
    if first_area <= 0 or second_area <= 0:
        return 0

    # Clamp each extent at zero so that disjoint boxes give an empty overlap.
    overlap_height = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    overlap_width = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    overlap_area = overlap_height * overlap_width
    union_area = first_area + second_area - overlap_area

    try:
        return overlap_area / union_area
    except ZeroDivisionError:
        # Defensive: a zero union is reported as no overlap.
        return 0
45
46
def yolo_processing(output: np.ndarray, confidence_threshold=0.40, iou_threshold=0.40):
    """
    Decodes YOLO output, keeps confident detections and applies non-maximum suppression.

    Detections with a confidence above ``confidence_threshold`` are collected and
    then suppressed greedily. On each pass the most confident remaining detection
    is kept, and every other remaining detection whose IoU with it exceeds
    ``iou_threshold`` is discarded. This repeats until no detections remain.
    Suppression ignores class, so boxes of different classes can suppress each other.

    Args:
        output: Sequence of network outputs, which must contain exactly one tensor.
            That tensor is indexed as (batch, detection, values) and only batch 0
            is decoded. Each detection row holds 4 box positions, then a confidence
            score, then per-class scores.
        confidence_threshold: Only detections with confidence strictly above this
            value are considered.
        iou_threshold: A detection whose IoU with a more confident kept detection is
            strictly above this value is suppressed.

    Returns:
        A list of (class_index, [box positions], confidence) tuples, ordered by
        descending confidence.

    Raises:
        RuntimeError: If ``output`` does not contain exactly one tensor.
    """
    if len(output) != 1:
        raise RuntimeError('Number of outputs from YOLO model does not equal 1')

    # Decode batch 0 only, so confidence, box and class index all come from the
    # same row (argmax over the whole [:, d, 5:] slice would flatten across batches).
    predictions = output[0][0]
    confident_rows = np.where(predictions[:, 4] > confidence_threshold)[0]

    candidates = [
        (np.argmax(predictions[row, 5:]), list(predictions[row, :4]), predictions[row, 4])
        for row in confident_rows
    ]

    nms_det = []
    while candidates:
        # np.argmax returns the first index on ties, so equally confident
        # detections are taken in their original row order.
        best = candidates.pop(int(np.argmax([det[2] for det in candidates])))
        nms_det.append(best)
        candidates = [det for det in candidates if iou(det[1], best[1]) <= iou_threshold]
    return nms_det
81
82
def yolo_resize_factor(video: cv2.VideoCapture, input_data_shape: tuple):
    """
    Compute the multiplier that maps box coordinates from model space to the frame.

    The factor is the longer side of the video frame divided by the longer of the
    two spatial dimensions of the model input.

    Args:
        video: Capture object queried for the frame height and width.
        input_data_shape: Four-element shape of the model input layer, unpacked
            as (_, height, width, _).

    Returns:
        Scale to apply to box coordinates so they land on the output frame.

    Raises:
        ValueError: If input_data_shape does not have exactly four elements.
    """
    longest_frame_side = max(video.get(cv2.CAP_PROP_FRAME_HEIGHT),
                             video.get(cv2.CAP_PROP_FRAME_WIDTH))
    _, input_height, input_width, _ = input_data_shape
    longest_model_side = max(input_height, input_width)
    return longest_frame_side / longest_model_side