Jakub Sujak | 433a595 | 2020-06-17 15:35:03 +0100 | [diff] [blame] | 1 | # Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
|
| 2 | # SPDX-License-Identifier: MIT
|
| 3 |
|
| 4 | """
|
| 5 | Object detection demo that takes a video file, runs inference on each frame producing
|
| 6 | bounding boxes and labels around detected objects, and saves the processed video.
|
| 7 | """
|
| 8 |
|
| 9 | import os
|
| 10 | import cv2
|
| 11 | import pyarmnn as ann
|
| 12 | from tqdm import tqdm
|
| 13 | from argparse import ArgumentParser
|
| 14 |
|
| 15 | from ssd import ssd_processing, ssd_resize_factor
|
| 16 | from yolo import yolo_processing, yolo_resize_factor
|
| 17 | from utils import create_video_writer, create_network, dict_labels, preprocess, execute_network, draw_bounding_boxes
|
| 18 |
|
| 19 |
|
# Command-line interface. Options are parsed once at import time so that the
# resulting `args` namespace is available to the __main__ guard below.
parser = ArgumentParser()
parser.add_argument(
    '--video_file_path',
    type=str, required=True,
    help='Path to the video file to run object detection on')
parser.add_argument(
    '--model_file_path',
    type=str, required=True,
    help='Path to the Object Detection model to use')
parser.add_argument(
    '--model_name',
    type=str, required=True,
    help='The name of the model being used. Accepted options: ssd_mobilenet_v1, yolo_v3_tiny')
parser.add_argument(
    '--label_path',
    type=str,
    help='Path to the labelset for the provided model file')
parser.add_argument(
    '--output_video_file_path',
    type=str,
    help='Path to the output video file with detections added in')
parser.add_argument(
    '--preferred_backends',
    type=str, nargs='+', default=['CpuAcc', 'CpuRef'],
    help='Takes the preferred backends in preference order, separated by whitespace, '
         'for example: CpuAcc GpuAcc CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]. '
         'Defaults to [CpuAcc, CpuRef]')
args = parser.parse_args()
|
| 36 |
|
| 37 |
|
def init_video(video_path: str, output_path: str):
    """
    Creates a video capture object from a video file.

    Args:
        video_path: User-specified video file path.
        output_path: Optional path to save the processed video.

    Returns:
        Video capture object to capture frames, video writer object to write processed
        frames to file, plus total frame count of video source to iterate through.

    Raises:
        FileNotFoundError: If the video file does not exist on disk.
        RuntimeError: If OpenCV fails to open the video file.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f'Video file not found for: {video_path}')
    video = cv2.VideoCapture(video_path)
    # Bug fix: isOpened is a method; referencing it without calling it is always
    # truthy, so `not video.isOpened` was always False and the failure check
    # below could never trigger.
    if not video.isOpened():
        raise RuntimeError(f'Failed to open video capture from file: {video_path}')

    video_writer = create_video_writer(video, video_path, output_path)
    iter_frame_count = range(int(video.get(cv2.CAP_PROP_FRAME_COUNT)))
    return video, video_writer, iter_frame_count
|
| 59 |
|
| 60 |
|
def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding_info: tuple):
    """
    Gets model-specific information such as model labels and decoding and processing functions.
    The user can include their own network and functions by adding another statement.

    Args:
        model_name: Name of type of supported model.
        video: Video capture object, contains information about data source.
        input_binding_info: Contains shape of model input layer, used for scaling bounding boxes.

    Returns:
        Model labels, decoding and processing functions.

    Raises:
        ValueError: If model_name is not one of the supported models.
    """
    # os.path.join() with a single component was a no-op; the bare filenames
    # are returned directly (resolved relative to the working directory).
    if model_name == 'ssd_mobilenet_v1':
        return 'ssd_labels.txt', ssd_processing, ssd_resize_factor(video)
    elif model_name == 'yolo_v3_tiny':
        return 'yolo_labels.txt', yolo_processing, yolo_resize_factor(video, input_binding_info)
    else:
        raise ValueError(f'{model_name} is not a valid model name')
|
| 82 |
|
| 83 |
|
def main(args):
    """
    Runs the demo end to end: opens the input video, loads the network, then
    detects objects in every frame, draws the results and writes each
    annotated frame to the output video file.
    """
    video, video_writer, frame_indices = init_video(args.video_file_path, args.output_video_file_path)
    net_id, runtime, input_binding_info, output_binding_info = create_network(args.model_file_path,
                                                                              args.preferred_backends)
    output_tensors = ann.make_output_tensors(output_binding_info)
    default_labels, process_output, resize_factor = get_model_processing(args.model_name, video, input_binding_info)
    # A user-supplied label file takes precedence over the model's default one.
    label_file = default_labels if args.label_path is None else args.label_path
    labels = dict_labels(label_file)

    for _ in tqdm(frame_indices, desc='Processing frames'):
        read_ok, frame = video.read()
        # Skip frames the decoder could not produce rather than aborting.
        if not read_ok:
            continue
        input_tensors = preprocess(frame, input_binding_info)
        raw_output = execute_network(input_tensors, output_tensors, runtime, net_id)
        detections = process_output(raw_output)
        draw_bounding_boxes(frame, detections, resize_factor, labels)
        video_writer.write(frame)
    print('Finished processing frames')
    video.release()
    video_writer.release()
|
| 103 |
|
| 104 |
|
# Script entry point; `args` is parsed at module import time above.
if __name__ == '__main__':
    main(args)
|