# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
|
| 2 | # SPDX-License-Identifier: MIT
|
| 3 |
|
| 4 | """
|
| 5 | Object detection demo that takes a video stream from a device, runs inference
|
| 6 | on each frame producing bounding boxes and labels around detected objects,
|
| 7 | and displays a window with the latest processed frame.
|
| 8 | """
|
| 9 |
|
| 10 | import os
|
| 11 | import cv2
|
| 12 | import pyarmnn as ann
|
| 13 | from tqdm import tqdm
|
| 14 | from argparse import ArgumentParser
|
| 15 |
|
| 16 | from ssd import ssd_processing, ssd_resize_factor
|
| 17 | from yolo import yolo_processing, yolo_resize_factor
|
| 18 | from utils import create_network, dict_labels, preprocess, execute_network, draw_bounding_boxes
|
| 19 |
|
| 20 |
|
# Command-line interface for the demo.
# NOTE(review): parse_args() runs at import time, so importing this module
# with unrelated sys.argv will fail — confirm this file is only ever run as
# a script, never imported.
parser = ArgumentParser()
# Integer device index passed straight to cv2.VideoCapture below.
parser.add_argument('--video_source', type=int, default=0,
                    help='Device index to access video stream. Defaults to primary device camera at index 0')
parser.add_argument('--model_file_path', required=True, type=str,
                    help='Path to the Object Detection model to use')
# Must match one of the names handled in get_model_processing().
parser.add_argument('--model_name', required=True, type=str,
                    help='The name of the model being used. Accepted options: ssd_mobilenet_v1, yolo_v3_tiny')
# Optional: when omitted, a default labels file for the chosen model is used.
parser.add_argument('--label_path', type=str,
                    help='Path to the labelset for the provided model file')
# nargs='+' collects the backends as a list, preserving preference order.
parser.add_argument('--preferred_backends', type=str, nargs='+', default=['CpuAcc', 'CpuRef'],
                    help='Takes the preferred backends in preference order, separated by whitespace, '
                         'for example: CpuAcc GpuAcc CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]. '
                         'Defaults to [CpuAcc, CpuRef]')
args = parser.parse_args()
|
| 35 |
|
| 36 |
|
def init_video(video_source: int):
    """
    Creates a video capture object from a device.

    Args:
        video_source: Device index used to read video stream.

    Returns:
        Video capture object used to capture frames from a video stream.

    Raises:
        RuntimeError: If the capture device cannot be opened.
    """
    video = cv2.VideoCapture(video_source)
    # Bug fix: isOpened is a method — the bare attribute `video.isOpened` is
    # always truthy, so the original check could never raise. Call it.
    if not video.isOpened():
        raise RuntimeError(f'Failed to open video capture for device with index: {video_source}')
    print('Processing video stream. Press \'Esc\' key to exit the demo.')
    return video
|
| 52 |
|
| 53 |
|
def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding_info: tuple):
    """
    Gets model-specific information such as model labels and decoding and processing functions.
    The user can include their own network and functions by adding another statement.

    Args:
        model_name: Name of type of supported model.
        video: Video capture object, contains information about data source.
        input_binding_info: Contains shape of model input layer, used for scaling bounding boxes.

    Returns:
        Model labels file path, decoding/processing function and resize factor.

    Raises:
        ValueError: If model_name is not a supported model name.
    """
    # Idiom fix: os.path.join() with a single argument is a no-op, so the
    # default label files are given as plain relative paths.
    if model_name == 'ssd_mobilenet_v1':
        return 'ssd_labels.txt', ssd_processing, ssd_resize_factor(video)
    elif model_name == 'yolo_v3_tiny':
        return 'yolo_labels.txt', yolo_processing, yolo_resize_factor(video, input_binding_info)
    else:
        raise ValueError(f'{model_name} is not a valid model name')
|
| 75 |
|
| 76 |
|
def main(args):
    """
    Runs the object detection demo loop.

    Reads frames from the configured video source, runs inference on each
    frame, draws bounding boxes and labels, and displays the result until
    the Esc key is pressed.

    Args:
        args: Parsed command-line arguments (see module-level ArgumentParser).

    Raises:
        RuntimeError: If a frame cannot be read from the video stream.
    """
    video = init_video(args.video_source)
    net_id, runtime, input_binding_info, output_binding_info = create_network(args.model_file_path,
                                                                             args.preferred_backends)
    output_tensors = ann.make_output_tensors(output_binding_info)
    labels, process_output, resize_factor = get_model_processing(args.model_name, video, input_binding_info)
    labels = dict_labels(labels if args.label_path is None else args.label_path)

    try:
        while True:
            frame_present, frame = video.read()
            # Bug fix: validate the read BEFORE touching the frame — the
            # original flipped first, so a failed read crashed inside
            # cv2.flip(None, 1) instead of raising the intended error.
            if not frame_present:
                raise RuntimeError('Error reading frame from video stream')
            frame = cv2.flip(frame, 1)  # Horizontally flip the frame
            input_tensors = preprocess(frame, input_binding_info)
            inference_output = execute_network(input_tensors, output_tensors, runtime, net_id)
            detections = process_output(inference_output)
            draw_bounding_boxes(frame, detections, resize_factor, labels)
            cv2.imshow('PyArmNN Object Detection Demo', frame)
            if cv2.waitKey(1) == 27:  # 27 == Esc key code
                print('\nExit key activated. Closing video...')
                break
    finally:
        # Release the capture device and close the preview window even when
        # inference or frame reading raises (original ran these only on a
        # clean break, via a tuple-expression one-liner).
        video.release()
        cv2.destroyAllWindows()
|
| 99 |
|
| 100 |
|
# Script entry point: run the demo with the arguments parsed at module
# import time (see the ArgumentParser block above).
if __name__ == '__main__':
    main(args)
|