blob: 61aa46c3d7991a123fa8b96301a212f9c39c4f44 [file] [log] [blame]
Éanna Ó Catháin145c88f2020-11-16 14:12:11 +00001# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
2# SPDX-License-Identifier: MIT
3
4"""
5This file contains helper functions for reading video/image data and
6 pre/postprocessing of video/image data using OpenCV.
7"""
8
9import os
10
11import cv2
12import numpy as np
13
14import pyarmnn as ann
15
16
def preprocess(frame: np.ndarray, input_binding_info: tuple):
    """
    Takes a frame, resizes, swaps channels and converts data type to match
    model input layer. The converted frame is wrapped in a const tensor
    and bound to the input tensor.

    Args:
        frame: Captured frame from video (converted here from BGR to RGB).
        input_binding_info: Contains shape and data type of model input layer;
            element 1 provides GetDataType() and GetShape().

    Returns:
        Input tensors as produced by pyarmnn's make_input_tensors.

    Raises:
        ValueError: If the preprocessed frame does not match the model input shape.
    """
    # Swap channels (BGR -> RGB) and resize/pad frame to the model resolution
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    resized_frame = resize_with_aspect_ratio(frame, input_binding_info)

    tensor_info = input_binding_info[1]
    # Only Float32 model inputs get float32 data; every other data type is fed as uint8.
    data_type = np.float32 if tensor_info.GetDataType() == ann.DataType_Float32 else np.uint8
    # Add a leading batch dimension of size 1.
    resized_frame = np.expand_dims(np.asarray(resized_frame, dtype=data_type), axis=0)

    # Explicit check instead of `assert`, so it is not stripped under `python -O`.
    expected_shape = tuple(tensor_info.GetShape())
    if resized_frame.shape != expected_shape:
        raise ValueError(f'Preprocessed frame shape {resized_frame.shape} does not match '
                         f'model input shape {expected_shape}')

    input_tensors = ann.make_input_tensors([input_binding_info], [resized_frame])
    return input_tensors
41
42
def resize_with_aspect_ratio(frame: np.ndarray, input_binding_info: tuple):
    """
    Resizes frame while maintaining aspect ratio, padding any empty space.

    The frame is scaled to fit entirely inside the model input, anchored at the
    top-left corner, and the leftover bottom or right area is filled with black.
    Works for square and non-square model inputs alike.

    Args:
        frame: Captured frame.
        input_binding_info: Contains shape of model input layer. Shape dimensions
            1 and 2 are read as (height, width), i.e. an NHWC layout is assumed.

    Returns:
        Frame resized and padded to the size of model input layer.
    """
    frame_height, frame_width = frame.shape[:2]
    aspect_ratio = frame_width / frame_height
    model_height, model_width = list(input_binding_info[1].GetShape())[1:3]

    # Compare against the model's own aspect ratio (not 1.0): otherwise, for a
    # non-square model input, the scaled frame can exceed the input on one axis
    # and the padding below would become negative.
    if aspect_ratio >= model_width / model_height:
        # Frame is relatively wider than the model input: width is the limiting side.
        new_width = model_width
        new_height = max(1, min(model_height, int(model_width / aspect_ratio)))
    else:
        # Frame is relatively taller than the model input: height is the limiting side.
        new_height = model_height
        new_width = max(1, min(model_width, int(model_height * aspect_ratio)))

    b_padding = model_height - new_height
    r_padding = model_width - new_width

    # Resize and pad any empty space
    frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
    frame = cv2.copyMakeBorder(frame, top=0, bottom=b_padding, left=0, right=r_padding,
                               borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0])
    return frame
69
70
def create_video_writer(video: 'cv2.VideoCapture', video_path: str, output_path: str):
    """
    Creates a video writer object to write processed frames to file.

    The output file is named object_detection_demo<ext>, where <ext> is the
    extension of video_path. If that file already exists, a counter is added:
    object_detection_demo(1)<ext>, object_detection_demo(2)<ext>, ...
    Codec, frame rate and frame size are copied from the source capture.

    Args:
        video: Video capture object, contains information about data source.
        video_path: User-specified video file path.
        output_path: Optional directory to save the processed video in. When None,
            the file name is relative to the current working directory.

    Returns:
        Video writer object.

    Raises:
        NotADirectoryError: If output_path is given but is not an existing directory.
    """
    _, ext = os.path.splitext(video_path)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if output_path is not None and not os.path.isdir(output_path):
        raise NotADirectoryError(f'Output path is not a directory: {output_path}')
    output_dir = output_path if output_path is not None else ''

    # Pick the first file name that does not already exist.
    i, filename = 0, os.path.join(output_dir, f'object_detection_demo{ext}')
    while os.path.exists(filename):
        i += 1
        filename = os.path.join(output_dir, f'object_detection_demo({i}){ext}')

    video_writer = cv2.VideoWriter(filename=filename,
                                   fourcc=get_source_encoding_int(video),
                                   # Keep fractional rates such as 29.97; truncating to int
                                   # would change the playback speed of the written video.
                                   fps=video.get(cv2.CAP_PROP_FPS),
                                   frameSize=(int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
                                              int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))))
    return video_writer
99
100
def init_video_file_capture(video_path: str, output_path: str):
    """
    Creates a video capture object from a video file.

    Args:
        video_path: User-specified video file path.
        output_path: Optional path to save the processed video.

    Returns:
        Video capture object to capture frames, video writer object to write processed
        frames to file, plus a range over the total frame count of the video source.

    Raises:
        FileNotFoundError: If video_path does not exist.
        RuntimeError: If OpenCV fails to open the video file.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f'Video file not found for: {video_path}')
    video = cv2.VideoCapture(video_path)
    # isOpened must be *called*: the bare bound method is always truthy, so
    # testing it without parentheses never detects a failed open.
    if not video.isOpened():
        raise RuntimeError(f'Failed to open video capture from file: {video_path}')

    try:
        video_writer = create_video_writer(video, video_path, output_path)
    except Exception:
        # Do not leak the opened capture if the writer cannot be created.
        video.release()
        raise
    iter_frame_count = range(int(video.get(cv2.CAP_PROP_FRAME_COUNT)))
    return video, video_writer, iter_frame_count
122
123
def init_video_stream_capture(video_source: int):
    """
    Creates a video capture object from a device.

    Args:
        video_source: Device index used to read video stream.

    Returns:
        Video capture object used to capture frames from a video stream.

    Raises:
        RuntimeError: If OpenCV fails to open the device.
    """
    video = cv2.VideoCapture(video_source)
    # isOpened must be *called*: the bare bound method is always truthy, so
    # testing it without parentheses never detects a failed open.
    if not video.isOpened():
        raise RuntimeError(f'Failed to open video capture for device with index: {video_source}')
    print('Processing video stream. Press \'Esc\' key to exit the demo.')
    return video
139
140
def draw_bounding_boxes(frame: np.ndarray, detections: list, resize_factor, labels: dict):
    """
    Draws bounding boxes around detected objects and adds a label and confidence score.

    The frame is drawn on in place; nothing is returned.

    Args:
        frame: The original captured frame from video source.
        detections: A list of detected objects in the form [class, [box positions], confidence],
            where box positions are (x_min, y_min, x_max, y_max) before scaling.
        resize_factor: Resizing factor to scale box coordinates to output frame size.
        labels: Dictionary of labels and colors keyed on the classification index;
            each value is indexed as (label_name, color).
    """
    # The frame size is the same for every detection, so read it once.
    frame_height, frame_width = frame.shape[:2]

    for class_idx, box, confidence in detections:
        label, color = labels[class_idx][0].capitalize(), labels[class_idx][1]

        # Scale box positions to the output frame and keep the box within it
        x_min, y_min, x_max, y_max = [int(position * resize_factor) for position in box]
        x_min, y_min = max(0, x_min), max(0, y_min)
        x_max, y_max = min(frame_width, x_max), min(frame_height, y_max)

        # Draw bounding box around detected object
        cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), color, 2)

        # Create label for detected object class; black text on bright box colors,
        # white text otherwise.
        label = f'{label} {confidence * 100:.1f}%'
        label_color = (0, 0, 0) if sum(color) > 200 else (255, 255, 255)

        # Text width/height measured at font scale 1; the 0.55 width factor below
        # appears to be an empirical fit for the 0.5-scale text actually drawn.
        x_text, y_text = cv2.getTextSize(label, cv2.FONT_HERSHEY_DUPLEX, 1, 1)[0][:2]

        # Make sure label stays on-screen vertically: for boxes within 25 px of the
        # top edge, draw the label just below y_min instead of above it.
        if y_min < 25:
            lbl_box_xy_min = (x_min, y_min)
            lbl_box_xy_max = (x_min + int(0.55 * x_text), y_min + y_text)
            lbl_text_pos = (x_min + 5, y_min + 16)
        else:
            lbl_box_xy_min = (x_min, y_min - y_text)
            lbl_box_xy_max = (x_min + int(0.55 * x_text), y_min)
            lbl_text_pos = (x_min + 5, y_min - 5)

        # Add filled label background, then the label and confidence value
        cv2.rectangle(frame, lbl_box_xy_min, lbl_box_xy_max, color, -1)
        cv2.putText(frame, label, lbl_text_pos, cv2.FONT_HERSHEY_DUPLEX, 0.50,
                    label_color, 1, cv2.LINE_AA)
181
182
def get_source_encoding_int(video_capture):
    """
    Reads the FOURCC codec code of a video capture source.

    Args:
        video_capture: Video capture object to query.

    Returns:
        The capture's CAP_PROP_FOURCC property value, converted to int.
    """
    fourcc_code = video_capture.get(cv2.CAP_PROP_FOURCC)
    return int(fourcc_code)