Jakub Sujak | 433a595 | 2020-06-17 15:35:03 +0100 | [diff] [blame] | 1 | # Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
|
| 2 | # SPDX-License-Identifier: MIT
|
| 3 |
|
| 4 | """
|
| 5 | Object detection demo that takes a video file, runs inference on each frame producing
|
| 6 | bounding boxes and labels around detected objects, and saves the processed video.
|
| 7 | """
|
| 8 |
|
| 9 | import os
|
Éanna Ó Catháin | 145c88f | 2020-11-16 14:12:11 +0000 | [diff] [blame] | 10 | import sys
|
| 11 | script_dir = os.path.dirname(__file__)
|
| 12 | sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
|
| 13 |
|
Jakub Sujak | 433a595 | 2020-06-17 15:35:03 +0100 | [diff] [blame] | 14 | import cv2
|
Jakub Sujak | 433a595 | 2020-06-17 15:35:03 +0100 | [diff] [blame] | 15 | from tqdm import tqdm
|
| 16 | from argparse import ArgumentParser
|
| 17 |
|
| 18 | from ssd import ssd_processing, ssd_resize_factor
|
| 19 | from yolo import yolo_processing, yolo_resize_factor
|
Éanna Ó Catháin | 145c88f | 2020-11-16 14:12:11 +0000 | [diff] [blame] | 20 | from utils import dict_labels
|
| 21 | from cv_utils import init_video_file_capture, preprocess, draw_bounding_boxes
|
| 22 | from network_executor import ArmnnNetworkExecutor
|
Jakub Sujak | 433a595 | 2020-06-17 15:35:03 +0100 | [diff] [blame] | 23 |
|
| 24 |
|
def get_model_processing(model_name: str, video: cv2.VideoCapture, input_binding_info: tuple):
    """
    Selects the output-processing function and the bounding-box resize factor
    that belong to the requested model.

    Support for another network can be added by extending the list of known
    names and adding a matching branch below.

    Args:
        model_name: Name of the supported model type to look up.
        video: Video capture object, passed to the model's resize-factor helper.
        input_binding_info: Input binding information of the network; only the
            'yolo_v3_tiny' resize-factor helper receives it.

    Returns:
        A pair of (processing function for the raw network output, resize factor).

    Raises:
        ValueError: If `model_name` is not one of the supported model names.
    """
    # Reject unknown names first so the branches below only handle supported models.
    if model_name not in ('ssd_mobilenet_v1', 'yolo_v3_tiny'):
        raise ValueError(f'{model_name} is not a valid model name')

    if model_name == 'ssd_mobilenet_v1':
        process_output = ssd_processing
        resize_factor = ssd_resize_factor(video)
    else:
        process_output = yolo_processing
        resize_factor = yolo_resize_factor(video, input_binding_info)

    return process_output, resize_factor
|
| 44 |
|
| 45 |
|
def main(args):
    """
    Runs object detection over every frame of a video file and writes the
    annotated frames to the output video.

    Args:
        args: Parsed command-line arguments. The attributes read here are
            `video_file_path`, `output_video_file_path`, `model_file_path`,
            `preferred_backends`, `model_name` and `label_path`.

    Raises:
        ValueError: Propagated from `get_model_processing` when `args.model_name`
            is not a supported model name.
    """
    video, video_writer, frame_count = init_video_file_capture(args.video_file_path, args.output_video_file_path)

    # Everything after the capture/writer are opened runs under try/finally so
    # both are released even when model loading or inference raises; otherwise
    # the handles leak and the output video may be left unfinalized.
    try:
        executor = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)
        process_output, resize_factor = get_model_processing(args.model_name, video, executor.input_binding_info)
        labels = dict_labels(args.label_path, include_rgb=True)

        # frame_count is iterated directly, so it is expected to be an iterable
        # (presumably one item per frame -- confirm against init_video_file_capture).
        for _ in tqdm(frame_count, desc='Processing frames'):
            frame_present, frame = video.read()
            if not frame_present:
                # Skip unreadable frames but keep going through the remaining count.
                continue
            input_tensors = preprocess(frame, executor.input_binding_info)
            output_result = executor.run(input_tensors)
            detections = process_output(output_result)
            # The return value is unused and the same frame is written next, so the
            # boxes are presumably drawn onto `frame` in place.
            draw_bounding_boxes(frame, detections, resize_factor, labels)
            video_writer.write(frame)
        print('Finished processing frames')
    finally:
        video.release()
        video_writer.release()
|
| 64 |
|
| 65 |
|
if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, run the demo.
    parser = ArgumentParser()

    # The first four options share the same shape (mandatory string values),
    # so they are registered from a table; order is kept for the help output.
    required_options = (
        ('--video_file_path', 'Path to the video file to run object detection on'),
        ('--model_file_path', 'Path to the Object Detection model to use'),
        ('--model_name', 'The name of the model being used. Accepted options: ssd_mobilenet_v1, yolo_v3_tiny'),
        ('--label_path', 'Path to the labelset for the provided model file'),
    )
    for option_flag, option_help in required_options:
        parser.add_argument(option_flag, required=True, type=str, help=option_help)

    # Optional destination for the annotated video.
    parser.add_argument(
        '--output_video_file_path',
        type=str,
        help='Path to the output video file with detections added in',
    )

    # One or more backend names, tried in the order given.
    backends_help = ('Takes the preferred backends in preference order, separated by whitespace, '
                     'for example: CpuAcc GpuAcc CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]. '
                     'Defaults to [CpuAcc, CpuRef]')
    parser.add_argument(
        '--preferred_backends',
        type=str,
        nargs='+',
        default=['CpuAcc', 'CpuRef'],
        help=backends_help,
    )

    args = parser.parse_args()
    main(args)
|