# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

"""
Object detection demo that takes a video stream from a device, runs inference
on each frame producing bounding boxes and labels around detected objects,
and displays a window with the latest processed frame.
"""

import os
import cv2
import pyarmnn as ann
from argparse import ArgumentParser

from ssd import ssd_processing, ssd_resize_factor
from yolo import yolo_processing, yolo_resize_factor
from utils import create_network, dict_labels, preprocess, execute_network, draw_bounding_boxes


parser = ArgumentParser()
parser.add_argument('--video_source', type=int, default=0,
                    help='Device index to access video stream. Defaults to primary device camera at index 0')
parser.add_argument('--model_file_path', required=True, type=str,
                    help='Path to the Object Detection model to use')
parser.add_argument('--model_name', required=True, type=str,
                    help='The name of the model being used. Accepted options: ssd_mobilenet_v1, yolo_v3_tiny')
parser.add_argument('--label_path', type=str,
                    help='Path to the labelset for the provided model file')
parser.add_argument('--preferred_backends', type=str, nargs='+', default=['CpuAcc', 'CpuRef'],
                    help='Takes the preferred backends in preference order, separated by whitespace, '
                         'for example: CpuAcc GpuAcc CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]. '
                         'Defaults to [CpuAcc, CpuRef]')
args = parser.parse_args()
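# Example invocation (script and model file names are illustrative; adjust paths to your setup):
#   python run_video_stream.py --model_file_path ssd_mobilenet_v1.tflite \
#       --model_name ssd_mobilenet_v1 --preferred_backends CpuAcc CpuRef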


def init_video(video_source: int):
    """
    Creates a video capture object from a device.

    Args:
        video_source: Device index used to read video stream.

    Returns:
        Video capture object used to capture frames from a video stream.
    """
    video = cv2.VideoCapture(video_source)
    if not video.isOpened():
        raise RuntimeError(f'Failed to open video capture for device with index: {video_source}')
    print('Processing video stream. Press \'Esc\' key to exit the demo.')
    return video


def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding_info: tuple):
    """
    Gets model-specific information such as model labels and decoding and processing functions.
    The user can include their own network and functions by adding another statement.

    Args:
        model_name: Name of type of supported model.
        video: Video capture object, contains information about data source.
        input_binding_info: Contains shape of model input layer, used for scaling bounding boxes.

    Returns:
        Model labels, decoding and processing functions.
    """
    if model_name == 'ssd_mobilenet_v1':
        labels = os.path.join('ssd_labels.txt')
        return labels, ssd_processing, ssd_resize_factor(video)
    elif model_name == 'yolo_v3_tiny':
        labels = os.path.join('yolo_labels.txt')
        return labels, yolo_processing, yolo_resize_factor(video, input_binding_info)
    else:
        raise ValueError(f'{model_name} is not a valid model name')


def main(args):
    video = init_video(args.video_source)
    net_id, runtime, input_binding_info, output_binding_info = create_network(args.model_file_path,
                                                                              args.preferred_backends)
    output_tensors = ann.make_output_tensors(output_binding_info)
    labels, process_output, resize_factor = get_model_processing(args.model_name, video, input_binding_info)
    labels = dict_labels(labels if args.label_path is None else args.label_path)

    while True:
        frame_present, frame = video.read()
        if not frame_present:
            raise RuntimeError('Error reading frame from video stream')
        frame = cv2.flip(frame, 1)  # Horizontally flip the frame
        input_tensors = preprocess(frame, input_binding_info)
        inference_output = execute_network(input_tensors, output_tensors, runtime, net_id)
        detections = process_output(inference_output)
        draw_bounding_boxes(frame, detections, resize_factor, labels)
        cv2.imshow('PyArmNN Object Detection Demo', frame)
        if cv2.waitKey(1) == 27:  # 27 is the 'Esc' key
            print('\nExit key activated. Closing video...')
            break
    video.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    main(args)