# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
3
4"""
5This file contains helper functions for reading video/image data and
6 pre/postprocessing of video/image data using OpenCV.
7"""
8
9import os
10
11import cv2
12import numpy as np
13
14import pyarmnn as ann
15
16
def preprocess(frame: np.ndarray, input_binding_info: tuple):
    """
    Takes a frame, resizes, swaps channels and converts data type to match
    model input layer. The converted frame is wrapped in a const tensor
    and bound to the input tensor.

    Args:
        frame: Captured frame from video.
        input_binding_info: Contains shape and data type of model input layer.

    Returns:
        Input tensor.

    Raises:
        ValueError: If the preprocessed frame does not match the model input shape.
    """
    # Swap channels (OpenCV captures BGR, the model expects RGB) and resize
    # the frame to the model resolution, padding to preserve aspect ratio
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    resized_frame = resize_with_aspect_ratio(frame, input_binding_info)

    # Convert data type to match the model input layer; float models expect
    # pixel values normalised from [0, 255] to [0.0, 1.0]
    if input_binding_info[1].GetDataType() == ann.DataType_Float32:
        data_type = np.float32
        resized_frame = resized_frame.astype("float32") / 255
    else:
        data_type = np.uint8

    # Add the leading batch dimension expected by the model
    resized_frame = np.expand_dims(np.asarray(resized_frame, dtype=data_type), axis=0)

    # Explicit check instead of assert: asserts are stripped under `python -O`
    expected_shape = tuple(input_binding_info[1].GetShape())
    if resized_frame.shape != expected_shape:
        raise ValueError(
            f'Preprocessed frame shape {resized_frame.shape} does not match '
            f'model input shape {expected_shape}')

    input_tensors = ann.make_input_tensors([input_binding_info], [resized_frame])
    return input_tensors
46
47
Éanna Ó Catháin65d5d2d2021-08-20 14:41:38 +010048
def resize_with_aspect_ratio(frame: np.ndarray, input_binding_info: tuple):
    """
    Resizes frame while maintaining aspect ratio, padding any empty space.

    Args:
        frame: Captured frame.
        input_binding_info: Contains shape of model input layer.

    Returns:
        Frame resized to the size of model input layer.
    """
    # Model input shape is assumed NHWC: indices 1 and 2 are height and width
    model_height, model_width = list(input_binding_info[1].GetShape())[1:3]
    aspect_ratio = frame.shape[1] / frame.shape[0]

    if aspect_ratio >= 1.0:
        # Landscape (or square) frame: fill the model width, pad the bottom
        new_width = model_width
        new_height = int(model_width / aspect_ratio)
        bottom_pad = model_height - new_height
        right_pad = 0
    else:
        # Portrait frame: fill the model height, pad the right edge
        new_height = model_height
        new_width = int(model_height * aspect_ratio)
        bottom_pad = 0
        right_pad = model_width - new_width

    # Resize, then pad the empty space with black pixels
    scaled = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
    return cv2.copyMakeBorder(scaled, top=0, bottom=bottom_pad, left=0, right=right_pad,
                              borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0])
75
76
def create_video_writer(video: cv2.VideoCapture, video_path: str, output_path: str):
    """
    Creates a video writer object to write processed frames to file.

    Args:
        video: Video capture object, contains information about data source.
        video_path: User-specified video file path.
        output_path: Optional path to save the processed video.

    Returns:
        Video writer object.

    Raises:
        NotADirectoryError: If output_path is supplied but is not an existing directory.
    """
    _, ext = os.path.splitext(video_path)

    # Explicit validation instead of assert: asserts are stripped under `python -O`
    if output_path is not None and not os.path.isdir(output_path):
        raise NotADirectoryError(f'Output path is not an existing directory: {output_path}')

    # An empty directory component makes os.path.join produce a cwd-relative name
    out_dir = output_path if output_path is not None else ''

    # Find a filename that does not clobber an existing file
    i, filename = 0, os.path.join(out_dir, f'object_detection_demo{ext}')
    while os.path.exists(filename):
        i += 1
        filename = os.path.join(out_dir, f'object_detection_demo({i}){ext}')

    # Mirror the source's codec, frame rate and resolution in the output file
    video_writer = cv2.VideoWriter(filename=filename,
                                   fourcc=get_source_encoding_int(video),
                                   fps=int(video.get(cv2.CAP_PROP_FPS)),
                                   frameSize=(int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
                                              int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))))
    return video_writer
105
106
def init_video_file_capture(video_path: str, output_path: str):
    """
    Creates a video capture object from a video file.

    Args:
        video_path: User-specified video file path.
        output_path: Optional path to save the processed video.

    Returns:
        Video capture object to capture frames, video writer object to write processed
        frames to file, plus total frame count of video source to iterate through.

    Raises:
        FileNotFoundError: If video_path does not exist.
        RuntimeError: If the video capture cannot be opened.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f'Video file not found for: {video_path}')
    video = cv2.VideoCapture(video_path)
    # BUG FIX: `isOpened` is a method — without the call the bound-method object
    # is always truthy, so the failure branch could never fire
    if not video.isOpened():
        raise RuntimeError(f'Failed to open video capture from file: {video_path}')

    video_writer = create_video_writer(video, video_path, output_path)
    iter_frame_count = range(int(video.get(cv2.CAP_PROP_FRAME_COUNT)))
    return video, video_writer, iter_frame_count
128
129
def init_video_stream_capture(video_source: int):
    """
    Creates a video capture object from a device.

    Args:
        video_source: Device index used to read video stream.

    Returns:
        Video capture object used to capture frames from a video stream.

    Raises:
        RuntimeError: If the device capture cannot be opened.
    """
    video = cv2.VideoCapture(video_source)
    # BUG FIX: `isOpened` is a method — without the call the bound-method object
    # is always truthy, so the failure branch could never fire
    if not video.isOpened():
        raise RuntimeError(f'Failed to open video capture for device with index: {video_source}')
    print('Processing video stream. Press \'Esc\' key to exit the demo.')
    return video
145
146
def draw_bounding_boxes(frame: np.ndarray, detections: list, resize_factor, labels: dict):
    """
    Draws bounding boxes around detected objects and adds a label and confidence score.

    Args:
        frame: The original captured frame from video source.
        detections: A list of detected objects in the form [class, [box positions], confidence].
        resize_factor: Resizing factor to scale box coordinates to output frame size.
        labels: Dictionary of labels and colors keyed on the classification index.
    """
    # Frame size is invariant across detections — hoisted out of the loop
    frame_height, frame_width = frame.shape[:2]

    for detection in detections:
        # Direct tuple unpack; the original copied the detection list first
        class_idx, box, confidence = detection
        label, color = labels[class_idx][0].capitalize(), labels[class_idx][1]

        # Scale bounding box positions to the output frame size
        x_min, y_min, x_max, y_max = [int(position * resize_factor) for position in box]

        # Ensure box stays within the frame
        x_min, y_min = max(0, x_min), max(0, y_min)
        x_max, y_max = min(frame_width, x_max), min(frame_height, y_max)

        # Draw bounding box around detected object
        cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), color, 2)

        # Create label for detected object class
        label = f'{label} {confidence * 100:.1f}%'
        # Black text on light box colors, white text on dark ones
        label_color = (0, 0, 0) if sum(color) > 200 else (255, 255, 255)

        # Make sure label always stays on-screen: draw above the box unless
        # the box is too close to the top of the frame
        x_text, y_text = cv2.getTextSize(label, cv2.FONT_HERSHEY_DUPLEX, 1, 1)[0][:2]

        lbl_box_xy_min = (x_min, y_min if y_min < 25 else y_min - y_text)
        lbl_box_xy_max = (x_min + int(0.55 * x_text), y_min + y_text if y_min < 25 else y_min)
        lbl_text_pos = (x_min + 5, y_min + 16 if y_min < 25 else y_min - 5)

        # Add label and confidence value
        cv2.rectangle(frame, lbl_box_xy_min, lbl_box_xy_max, color, -1)
        cv2.putText(frame, label, lbl_text_pos, cv2.FONT_HERSHEY_DUPLEX, 0.50,
                    label_color, 1, cv2.LINE_AA)
187
188
def get_source_encoding_int(video_capture):
    """Returns the FOURCC codec identifier of the capture source as an int."""
    fourcc = video_capture.get(cv2.CAP_PROP_FOURCC)
    return int(fourcc)