blob: e12ff50548f7859dec2f6e98752e962a3958c5a0 [file] [log] [blame]
alexanderf42f5682021-07-16 11:30:56 +01001# Copyright © 2020-2021 Arm Ltd and Contributors. All rights reserved.
Éanna Ó Catháin145c88f2020-11-16 14:12:11 +00002# SPDX-License-Identifier: MIT
3
4"""
5This file contains helper functions for reading video/image data and
6 pre/postprocessing of video/image data using OpenCV.
7"""
8
9import os
10
11import cv2
12import numpy as np
13
14import pyarmnn as ann
15
16
def preprocess(frame: np.ndarray, input_binding_info: tuple, is_normalised: bool):
    """
    Converts a captured frame into the input tensors expected by the model.

    The frame is converted from BGR to RGB, resized (with padding) to the
    model resolution, cast to the data type of the model input layer and
    given a leading batch dimension before being wrapped via
    ann.make_input_tensors.

    Args:
        frame: Captured frame from video.
        input_binding_info: Contains shape and data type of model input layer.
        is_normalised: if the input layer expects normalised data; only
            applied when the input layer data type is Float32.

    Returns:
        Input tensor.
    """
    # Channel swap first, then fit the frame to the model resolution
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    fitted_frame = resize_with_aspect_ratio(rgb_frame, input_binding_info)

    tensor_info = input_binding_info[1]

    # Anything that is not Float32 is fed to the model as uint8
    expects_float = tensor_info.GetDataType() == ann.DataType_Float32
    if expects_float and is_normalised:
        # Scale pixel values by 1/255 for models expecting normalised input
        fitted_frame = fitted_frame.astype("float32")/255
    target_dtype = np.float32 if expects_float else np.uint8

    # Add the batch dimension and check the result against the model input shape
    batched_frame = np.expand_dims(np.asarray(fitted_frame, dtype=target_dtype), axis=0)
    assert batched_frame.shape == tuple(tensor_info.GetShape())

    return ann.make_input_tensors([input_binding_info], [batched_frame])
48
49
Éanna Ó Catháin65d5d2d2021-08-20 14:41:38 +010050
def resize_with_aspect_ratio(frame: np.ndarray, input_binding_info: tuple):
    """
    Resizes frame while maintaining aspect ratio, padding any empty space.

    The frame is scaled so that it fits inside the model input resolution,
    then padded with black pixels on the bottom and/or right edge so the
    result is exactly model_height x model_width.

    Args:
        frame: Captured frame.
        input_binding_info: Contains shape of model input layer. Elements 1
            and 2 of the shape are used as height and width (assumes a
            batch-first, channels-last layout -- confirm for other models).

    Returns:
        Frame resized to the size of model input layer.
    """
    aspect_ratio = frame.shape[1] / frame.shape[0]
    model_height, model_width = list(input_binding_info[1].GetShape())[1:3]

    # Candidate sizes as (height, width): one fills the model width, the
    # other fills the model height.
    fill_width = (int(model_width / aspect_ratio), model_width)
    fill_height = (model_height, int(model_height * aspect_ratio))

    # Landscape frames fill the width, portrait frames fill the height.
    new_height, new_width = fill_width if aspect_ratio >= 1.0 else fill_height

    # With a non-square model input the preferred candidate may not fit
    # (e.g. a 4:3 frame into a wide, short input), which would otherwise
    # produce negative padding. Fall back to the other candidate, which is
    # guaranteed to fit in that case.
    if new_height > model_height:
        new_height, new_width = fill_height
    elif new_width > model_width:
        new_height, new_width = fill_width

    b_padding = model_height - new_height
    r_padding = model_width - new_width

    # Resize and pad any empty space
    frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
    frame = cv2.copyMakeBorder(frame, top=0, bottom=b_padding, left=0, right=r_padding,
                               borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0])
    return frame
77
78
def create_video_writer(video: cv2.VideoCapture, video_path: str, output_path: str):
    """
    Creates a video writer object to write processed frames to file.

    The output file is named 'object_detection_demo' with the extension of
    the source video. If that name is already taken, a counter in
    parentheses is appended until an unused name is found.

    Args:
        video: Video capture object, contains information about data source.
        video_path: User-specified video file path.
        output_path: Optional path to save the processed video. When None,
            the file name is used without a directory prefix.

    Returns:
        Video writer object.
    """
    extension = os.path.splitext(video_path)[1]

    if output_path is not None:
        assert os.path.isdir(output_path)
    target_dir = str() if output_path is None else output_path

    # Probe for the first file name that does not exist yet
    filename = os.path.join(target_dir, f'object_detection_demo{extension}')
    i = 0
    while os.path.exists(filename):
        i += 1
        filename = os.path.join(target_dir, f'object_detection_demo({i}){extension}')

    # Mirror codec, frame rate and frame size of the source video
    fourcc = get_source_encoding_int(video)
    fps = int(video.get(cv2.CAP_PROP_FPS))
    frame_size = (int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
                  int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)))

    return cv2.VideoWriter(filename=filename,
                           fourcc=fourcc,
                           fps=fps,
                           frameSize=frame_size)
107
108
def init_video_file_capture(video_path: str, output_path: str):
    """
    Creates a video capture object from a video file.

    Args:
        video_path: User-specified video file path.
        output_path: Optional path to save the processed video.

    Returns:
        Video capture object to capture frames, video writer object to write processed
        frames to file, plus total frame count of video source to iterate through.

    Raises:
        FileNotFoundError: If video_path does not exist.
        RuntimeError: If the video file could not be opened for capture.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f'Video file not found for: {video_path}')
    video = cv2.VideoCapture(video_path)
    # isOpened is a method and must be called: the bare attribute is a bound
    # method object, which is always truthy, so the check would never fail.
    if not video.isOpened():
        raise RuntimeError(f'Failed to open video capture from file: {video_path}')

    video_writer = create_video_writer(video, video_path, output_path)
    # Range over the frame count reported by the capture object
    iter_frame_count = range(int(video.get(cv2.CAP_PROP_FRAME_COUNT)))
    return video, video_writer, iter_frame_count
130
131
def init_video_stream_capture(video_source: int):
    """
    Creates a video capture object from a device.

    Args:
        video_source: Device index used to read video stream.

    Returns:
        Video capture object used to capture frames from a video stream.

    Raises:
        RuntimeError: If the device could not be opened for capture.
    """
    video = cv2.VideoCapture(video_source)
    # isOpened is a method and must be called: the bare attribute is a bound
    # method object, which is always truthy, so the check would never fail.
    if not video.isOpened():
        raise RuntimeError(f'Failed to open video capture for device with index: {video_source}')
    print('Processing video stream. Press \'Esc\' key to exit the demo.')
    return video
147
148
def draw_bounding_boxes(frame: np.ndarray, detections: list, resize_factor, labels: dict):
    """
    Draws bounding boxes around detected objects and adds a label and confidence score.

    The frame is modified in place; nothing is returned.

    Args:
        frame: The original captured frame from video source.
        detections: A list of detected objects in the form [class, [box positions], confidence].
        resize_factor: Resizing factor to scale box coordinates to output frame size.
        labels: Dictionary of labels and colors keyed on the classification index.
    """
    for class_idx, box, confidence in detections:
        label_entry = labels[class_idx]
        label = label_entry[0].capitalize()
        color = label_entry[1]

        # Scale the box to the output frame and clamp it to the frame bounds
        frame_height, frame_width = frame.shape[:2]
        x_min, y_min, x_max, y_max = (int(position * resize_factor) for position in box)
        x_min = max(0, x_min)
        y_min = max(0, y_min)
        x_max = min(frame_width, x_max)
        y_max = min(frame_height, y_max)

        # Outline of the detected object
        cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), color, 2)

        # Label text, in black on bright box colors and white otherwise
        label = f'{label} {confidence * 100:.1f}%'
        if sum(color) > 200:
            label_color = (0, 0, 0)
        else:
            label_color = (255, 255, 255)

        x_text, y_text = cv2.getTextSize(label, cv2.FONT_HERSHEY_DUPLEX, 1, 1)[0][:2]

        # Boxes close to the top edge get the label inside the box so it
        # stays on-screen; otherwise the label sits just above the box.
        if y_min < 25:
            lbl_top, lbl_bottom, text_y = y_min, y_min + y_text, y_min + 16
        else:
            lbl_top, lbl_bottom, text_y = y_min - y_text, y_min, y_min - 5

        lbl_box_xy_min = (x_min, lbl_top)
        lbl_box_xy_max = (x_min + int(0.55 * x_text), lbl_bottom)
        lbl_text_pos = (x_min + 5, text_y)

        # Filled label background followed by the label text
        cv2.rectangle(frame, lbl_box_xy_min, lbl_box_xy_max, color, -1)
        cv2.putText(frame, label, lbl_text_pos, cv2.FONT_HERSHEY_DUPLEX, 0.50,
                    label_color, 1, cv2.LINE_AA)
189
190
def get_source_encoding_int(video_capture):
    """
    Reads the FOURCC codec property of a video source as an integer.

    Args:
        video_capture: Object whose get() method is queried with cv2.CAP_PROP_FOURCC.

    Returns:
        The CAP_PROP_FOURCC property value converted to int.
    """
    fourcc_value = video_capture.get(cv2.CAP_PROP_FOURCC)
    return int(fourcc_value)