blob: 36d10392278c5a9ac2e04cc16df69f5dd51ab409 [file] [log] [blame]
Raviv Shalev97ddc062021-12-07 15:18:09 +02001# Copyright © 2020-2022 Arm Ltd and Contributors. All rights reserved.
Éanna Ó Catháin145c88f2020-11-16 14:12:11 +00002# SPDX-License-Identifier: MIT
3
4"""
5This file contains helper functions for reading video/image data and
6 pre/postprocessing of video/image data using OpenCV.
7"""
8
9import os
10
11import cv2
12import numpy as np
13
Éanna Ó Catháin145c88f2020-11-16 14:12:11 +000014
def preprocess(frame: np.ndarray, input_data_type, input_data_shape: tuple, is_normalised: bool,
               keep_aspect_ratio: bool = True):
    """
    Takes a frame, resizes, swaps channels and converts data type to match
    model input layer.

    Args:
        frame: Captured frame from video (BGR channel order, as read by OpenCV).
        input_data_type: Contains data type of model input layer.
        input_data_shape: Contains shape of model input layer (NHWC layout).
        is_normalised: If the input layer expects normalised (0.0-1.0) float data.
        keep_aspect_ratio: If True, letterbox-resize (pad) instead of stretching.

    Returns:
        Input tensor with shape equal to input_data_shape.

    Raises:
        ValueError: If the prepared tensor does not match the model input shape.
    """
    # Model expects RGB channel order; OpenCV captures in BGR
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    if keep_aspect_ratio:
        # Resize preserving aspect ratio, padding any empty space
        resized_frame = resize_with_aspect_ratio(frame, input_data_shape)
    else:
        # Stretch directly to the model resolution (NHWC: index 1 is height, 2 is width)
        frame_height = input_data_shape[1]
        frame_width = input_data_shape[2]
        resized_frame = cv2.resize(frame, (frame_width, frame_height))

    # Convert data type (and scale, if requested) to match model input
    if np.float32 == input_data_type:
        data_type = np.float32
        if is_normalised:
            # Scale 8-bit pixel values into the 0.0-1.0 range
            resized_frame = resized_frame.astype("float32") / 255
    else:
        data_type = np.uint8

    # Add the leading batch dimension expected by the model
    resized_frame = np.expand_dims(np.asarray(resized_frame, dtype=data_type), axis=0)
    # Explicit check instead of `assert`, which is silently stripped under `python -O`
    if resized_frame.shape != tuple(input_data_shape):
        raise ValueError(f'Prepared input shape {resized_frame.shape} does not match '
                         f'model input shape {tuple(input_data_shape)}')

    return resized_frame
Éanna Ó Catháin145c88f2020-11-16 14:12:11 +000053
54
def resize_with_aspect_ratio(frame: np.ndarray, input_data_shape: tuple):
    """
    Resizes frame while maintaining aspect ratio, padding any empty space.

    Args:
        frame: Captured frame.
        input_data_shape: Contains shape of model input layer (NHWC layout).

    Returns:
        Frame resized to the size of model input layer, padded at the
        bottom/right with black pixels where needed.
    """
    frame_height, frame_width = frame.shape[:2]
    _, model_height, model_width, _ = input_data_shape

    # Scale so the whole frame fits inside the model resolution. The previous
    # per-branch formulas assumed a square-ish model input and could compute a
    # negative padding (which cv2.copyMakeBorder rejects) whenever the frame
    # and model aspect ratios disagreed.
    scale = min(model_width / frame_width, model_height / frame_height)
    new_width, new_height = int(frame_width * scale), int(frame_height * scale)
    b_padding = model_height - new_height
    r_padding = model_width - new_width

    # Resize and pad any empty space
    frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
    frame = cv2.copyMakeBorder(frame, top=0, bottom=b_padding, left=0, right=r_padding,
                               borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0])
    return frame
81
82
def create_video_writer(video: cv2.VideoCapture, video_path: str, output_path: str):
    """
    Creates a video writer object to write processed frames to file.

    Args:
        video: Video capture object, contains information about data source.
        video_path: User-specified video file path.
        output_path: Optional path to save the processed video.

    Returns:
        Video writer object.

    Raises:
        NotADirectoryError: If output_path is supplied but is not an existing directory.
    """
    # Reuse the input file's extension so container and codec stay compatible
    _, ext = os.path.splitext(video_path)

    # Explicit check instead of `assert`, which is silently stripped under `python -O`
    if output_path is not None and not os.path.isdir(output_path):
        raise NotADirectoryError(f'Output path is not a directory: {output_path}')

    out_dir = output_path if output_path is not None else str()

    # Choose a filename that does not clobber an existing output file
    i = 0
    filename = os.path.join(out_dir, f'object_detection_demo{ext}')
    while os.path.exists(filename):
        i += 1
        filename = os.path.join(out_dir, f'object_detection_demo({i}){ext}')

    video_writer = cv2.VideoWriter(filename=filename,
                                   fourcc=get_source_encoding_int(video),
                                   fps=int(video.get(cv2.CAP_PROP_FPS)),
                                   frameSize=(int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
                                              int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))))
    return video_writer
111
112
def init_video_file_capture(video_path: str, output_path: str):
    """
    Creates a video capture object from a video file.

    Args:
        video_path: User-specified video file path.
        output_path: Optional path to save the processed video.

    Returns:
        Video capture object to capture frames, video writer object to write processed
        frames to file, plus total frame count of video source to iterate through.

    Raises:
        FileNotFoundError: If video_path does not exist.
        RuntimeError: If the video capture cannot be opened.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f'Video file not found for: {video_path}')
    video = cv2.VideoCapture(video_path)
    # BUG FIX: `isOpened` must be *called* — the bare bound method is always
    # truthy, so the original `if not video.isOpened:` could never trigger.
    if not video.isOpened():
        raise RuntimeError(f'Failed to open video capture from file: {video_path}')

    video_writer = create_video_writer(video, video_path, output_path)
    iter_frame_count = range(int(video.get(cv2.CAP_PROP_FRAME_COUNT)))
    return video, video_writer, iter_frame_count
134
135
def init_video_stream_capture(video_source: int):
    """
    Creates a video capture object from a device.

    Args:
        video_source: Device index used to read video stream.

    Returns:
        Video capture object used to capture frames from a video stream.

    Raises:
        RuntimeError: If the capture device cannot be opened.
    """
    video = cv2.VideoCapture(video_source)
    # BUG FIX: `isOpened` must be *called* — the bare bound method is always
    # truthy, so the original `if not video.isOpened:` could never trigger.
    if not video.isOpened():
        raise RuntimeError(f'Failed to open video capture for device with index: {video_source}')
    print('Processing video stream. Press \'Esc\' key to exit the demo.')
    return video
151
152
def draw_bounding_boxes(frame: np.ndarray, detections: list, resize_factor, labels: dict):
    """
    Draws bounding boxes around detected objects and adds a label and confidence score.

    Args:
        frame: The original captured frame from video source (annotated in place).
        detections: A list of detected objects in the form [class, [box positions], confidence].
        resize_factor: Resizing factor to scale box coordinates to output frame size.
        labels: Dictionary of labels and colors keyed on the classification index.
    """
    # Frame size is loop-invariant; obtain it once instead of per detection
    frame_height, frame_width = frame.shape[:2]

    for detection in detections:
        # Unpack directly — no need to copy the detection through a list first
        class_idx, box, confidence = detection
        label, color = labels[class_idx][0].capitalize(), labels[class_idx][1]

        # Scale bounding box positions to the output frame size
        x_min, y_min, x_max, y_max = [int(position * resize_factor) for position in box]

        # Ensure box stays within the frame
        x_min, y_min = max(0, x_min), max(0, y_min)
        x_max, y_max = min(frame_width, x_max), min(frame_height, y_max)

        # Draw bounding box around detected object
        cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), color, 2)

        # Create label for detected object class
        label = f'{label} {confidence * 100:.1f}%'
        # Black text on bright box colors, white text on dark ones
        label_color = (0, 0, 0) if sum(color) > 200 else (255, 255, 255)

        # Make sure label always stays on-screen
        x_text, y_text = cv2.getTextSize(label, cv2.FONT_HERSHEY_DUPLEX, 1, 1)[0][:2]

        lbl_box_xy_min = (x_min, y_min if y_min < 25 else y_min - y_text)
        lbl_box_xy_max = (x_min + int(0.55 * x_text), y_min + y_text if y_min < 25 else y_min)
        lbl_text_pos = (x_min + 5, y_min + 16 if y_min < 25 else y_min - 5)

        # Add label and confidence value
        cv2.rectangle(frame, lbl_box_xy_min, lbl_box_xy_max, color, -1)
        cv2.putText(frame, label, lbl_text_pos, cv2.FONT_HERSHEY_DUPLEX, 0.50,
                    label_color, 1, cv2.LINE_AA)
193
194
def get_source_encoding_int(video_capture):
    """Returns the FourCC codec code of the capture source as an int, suitable for cv2.VideoWriter."""
    return int(video_capture.get(cv2.CAP_PROP_FOURCC))
Raviv Shalev97ddc062021-12-07 15:18:09 +0200197
198
def crop_bounding_box_object(input_frame: np.ndarray, x_min: float, y_min: float, x_max: float, y_max: float):
    """
    Creates a cropped image based on x and y coordinates.

    Args:
        input_frame: Image to crop.
        x_min, y_min, x_max, y_max: Coordinates of the bounding box.

    Returns:
        Cropped image.
    """
    # Start one pixel inside the box so the bounding-box line itself is
    # excluded from the crop.
    row_start, row_stop = int(y_min) + 1, int(y_max)
    col_start, col_stop = int(x_min) + 1, int(x_max)
    return input_frame[row_start:row_stop, col_start:col_stop]