# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
2# SPDX-License-Identifier: MIT
3
4"""
5Object detection demo that takes a video file, runs inference on each frame producing
6bounding boxes and labels around detected objects, and saves the processed video.
7"""
8
9import os
10import cv2
11import pyarmnn as ann
12from tqdm import tqdm
13from argparse import ArgumentParser
14
15from ssd import ssd_processing, ssd_resize_factor
16from yolo import yolo_processing, yolo_resize_factor
17from utils import create_video_writer, create_network, dict_labels, preprocess, execute_network, draw_bounding_boxes
18
19
def _build_arg_parser() -> ArgumentParser:
    """Construct the command-line interface for the object detection demo."""
    cli = ArgumentParser()
    cli.add_argument('--video_file_path', required=True, type=str,
                     help='Path to the video file to run object detection on')
    cli.add_argument('--model_file_path', required=True, type=str,
                     help='Path to the Object Detection model to use')
    cli.add_argument('--model_name', required=True, type=str,
                     help='The name of the model being used. Accepted options: ssd_mobilenet_v1, yolo_v3_tiny')
    cli.add_argument('--label_path', type=str,
                     help='Path to the labelset for the provided model file')
    cli.add_argument('--output_video_file_path', type=str,
                     help='Path to the output video file with detections added in')
    cli.add_argument('--preferred_backends', type=str, nargs='+', default=['CpuAcc', 'CpuRef'],
                     help='Takes the preferred backends in preference order, separated by whitespace, '
                          'for example: CpuAcc GpuAcc CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]. '
                          'Defaults to [CpuAcc, CpuRef]')
    return cli


# Parsed once at import time; `main(args)` consumes this in the entry-point guard.
args = _build_arg_parser().parse_args()
36
37
def init_video(video_path: str, output_path: str):
    """
    Creates a video capture object from a video file.

    Args:
        video_path: User-specified video file path.
        output_path: Optional path to save the processed video.

    Returns:
        Video capture object to capture frames, video writer object to write processed
        frames to file, plus total frame count of video source to iterate through.

    Raises:
        FileNotFoundError: If the video file does not exist on disk.
        RuntimeError: If OpenCV fails to open the video file.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f'Video file not found for: {video_path}')
    video = cv2.VideoCapture(video_path)
    # Bug fix: isOpened is a method. The previous `not video.isOpened` tested
    # the bound-method object itself (always truthy), so a failed open was
    # never detected; it must be called.
    if not video.isOpened():
        raise RuntimeError(f'Failed to open video capture from file: {video_path}')

    video_writer = create_video_writer(video, video_path, output_path)
    iter_frame_count = range(int(video.get(cv2.CAP_PROP_FRAME_COUNT)))
    return video, video_writer, iter_frame_count
59
60
61def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding_info: tuple):
62 """
63 Gets model-specific information such as model labels and decoding and processing functions.
64 The user can include their own network and functions by adding another statement.
65
66 Args:
67 model_name: Name of type of supported model.
68 video: Video capture object, contains information about data source.
69 input_binding_info: Contains shape of model input layer, used for scaling bounding boxes.
70
71 Returns:
72 Model labels, decoding and processing functions.
73 """
74 if model_name == 'ssd_mobilenet_v1':
75 labels = os.path.join('ssd_labels.txt')
76 return labels, ssd_processing, ssd_resize_factor(video)
77 elif model_name == 'yolo_v3_tiny':
78 labels = os.path.join('yolo_labels.txt')
79 return labels, yolo_processing, yolo_resize_factor(video, input_binding_info)
80 else:
81 raise ValueError(f'{model_name} is not a valid model name')
82
83
def main(args):
    """
    Runs the object detection pipeline end to end: loads the model and video,
    runs inference on each frame, draws bounding boxes and writes the annotated
    frames to the output video file.

    Args:
        args: Parsed command-line arguments (see the module-level ArgumentParser).
    """
    video, video_writer, frame_count = init_video(args.video_file_path, args.output_video_file_path)
    net_id, runtime, input_binding_info, output_binding_info = create_network(args.model_file_path,
                                                                              args.preferred_backends)
    output_tensors = ann.make_output_tensors(output_binding_info)
    labels, process_output, resize_factor = get_model_processing(args.model_name, video, input_binding_info)
    # A user-supplied label file takes precedence over the model's default labelset.
    labels = dict_labels(labels if args.label_path is None else args.label_path)

    try:
        for _ in tqdm(frame_count, desc='Processing frames'):
            frame_present, frame = video.read()
            if not frame_present:
                # Skip unreadable/dropped frames rather than aborting the run.
                continue
            input_tensors = preprocess(frame, input_binding_info)
            inference_output = execute_network(input_tensors, output_tensors, runtime, net_id)
            detections = process_output(inference_output)
            draw_bounding_boxes(frame, detections, resize_factor, labels)
            video_writer.write(frame)
        print('Finished processing frames')
    finally:
        # Release the capture and writer even if inference fails part-way so the
        # output file is finalised. Replaces the former throwaway tuple
        # expression `video.release(), video_writer.release()`.
        video.release()
        video_writer.release()
103
104
# Script entry point: `args` is parsed at module import time by the
# module-level ArgumentParser above.
if __name__ == '__main__':
    main(args)