Raviv Shalev | 97ddc06 | 2021-12-07 15:18:09 +0200 | [diff] [blame] | 1 | # Copyright © 2022 Arm Ltd and Contributors. All rights reserved. |
| 2 | # SPDX-License-Identifier: MIT |
| 3 | |
| 4 | import numpy as np |
| 5 | import urllib.request |
| 6 | import cv2 |
| 7 | import network_executor_tflite |
| 8 | import cv_utils |
| 9 | |
| 10 | |
def style_transfer_postprocess(preprocessed_frame: np.ndarray, image_shape: tuple) -> np.ndarray:
    """
    Resizes the output frame of the style transfer network and converts the
    color channel order back to the original (BGR) configuration.

    Args:
        preprocessed_frame: The raw network output: a batched RGB frame of
                            shape (1, H, W, C), assumed to hold values in
                            [0, 1] (hence the * 255 rescale below) —
                            TODO confirm against the model's output range.
        image_shape: Shape of the original frame before preprocessing,
                     (height, width, ...) as returned by ndarray.shape.

    Returns:
        The stylized frame resized to the original dimensions, scaled to
        [0, 255] as float32 and converted from RGB back to BGR.
    """
    # Drop the batch dimension: (1, H, W, C) -> (H, W, C)
    stylized = np.squeeze(preprocessed_frame, axis=0)
    # Select original height and width from image_shape
    frame_height = image_shape[0]
    frame_width = image_shape[1]
    # Resize back to the source resolution and rescale [0, 1] -> [0, 255]
    stylized = cv2.resize(stylized, (frame_width, frame_height)).astype("float32") * 255
    # The network works in RGB; OpenCV frames are BGR
    return cv2.cvtColor(stylized, cv2.COLOR_RGB2BGR)
| 31 | |
| 32 | |
def create_stylized_detection(style_transfer_executor, style_transfer_class, frame: np.ndarray,
                              detections: list, resize_factor, labels: dict):
    """
    Perform style transfer on every detected object in the frame whose label
    matches style_transfer_class, pasting each stylized crop back in place.

    Args:
        style_transfer_executor: The style transfer executor
        style_transfer_class: The class detected to change its style
        frame: The original captured frame from video source.
        detections: A list of detected objects in the form [class, [box positions], confidence].
        resize_factor: Resizing factor to scale box coordinates to output frame size.
        labels: Dictionary of labels and colors keyed on the classification index.

    Returns:
        The input frame (modified in place) with matching detections replaced
        by their stylized versions.
    """
    # Frame size is loop-invariant; obtain it once.
    frame_height, frame_width = frame.shape[:2]

    for class_idx, box, _confidence in detections:
        label = labels[class_idx][0]
        if label.lower() != style_transfer_class.lower():
            continue

        # Scale bounding box coordinates to the output frame size.
        x_min, y_min, x_max, y_max = [int(position * resize_factor) for position in box]

        # Clamp the box so it stays within the frame.
        x_min, y_min = max(0, x_min), max(0, y_min)
        x_max, y_max = min(frame_width, x_max), min(frame_height, y_max)

        # Crop only the detected object.
        cropped_frame = cv_utils.crop_bounding_box_object(frame, x_min, y_min, x_max, y_max)

        # Run style transfer on the cropped region.
        stylized_frame = style_transfer_executor.run_style_transfer(cropped_frame)

        # Paste stylized_frame on the original frame in the correct place.
        # NOTE(review): the +1 offsets presumably mirror the cropping
        # convention of cv_utils.crop_bounding_box_object — confirm both use
        # the same inclusive/exclusive bounds so the shapes match.
        frame[y_min + 1:y_max, x_min + 1:x_max] = stylized_frame

    return frame
| 68 | |
| 69 | |
class StyleTransfer:
    """
    Wraps a style-predict network and a style-transfer network: the predict
    network turns a style image into a bottleneck tensor (computed once in
    __init__), and the transfer network applies that style to content frames.
    """

    def __init__(self, style_predict_model_path: str, style_transfer_model_path: str,
                 style_image: np.ndarray, backends: list, delegate_path: str):
        """
        Creates an inference executor for style predict network, style transfer network,
        list of backends and a style image.

        Args:
            style_predict_model_path: model which is used to create a style bottleneck
            style_transfer_model_path: model which is used to create stylized frames
            style_image: an image to create the style bottleneck
            backends: List of backends to optimize network.
            delegate_path: tflite delegate file path (.so).
        """
        self.style_predict_executor = network_executor_tflite.TFLiteNetworkExecutor(
            style_predict_model_path, backends, delegate_path)
        self.style_transfer_executor = network_executor_tflite.TFLiteNetworkExecutor(
            style_transfer_model_path, backends, delegate_path)
        # The style bottleneck is computed once here and reused for every frame.
        self.style_bottleneck = self.run_style_predict(style_image)

    def get_style_predict_executor_shape(self):
        """
        Get the input shape of the initiated network.

        Returns:
            tuple: The Shape of the network input.
        """
        return self.style_predict_executor.get_shape()

    def run_style_predict(self, style_image):
        """
        Creates bottleneck tensor for a given style image.

        Args:
            style_image: an image to create the style bottleneck

        Returns:
            style bottleneck tensor
        """
        # The style image has to be preprocessed to the predict network's
        # input shape (e.g. (1, 256, 256, 3)).
        preprocessed_style_image = cv_utils.preprocess(
            style_image, self.style_predict_executor.get_data_type(),
            self.style_predict_executor.get_shape(), True, keep_aspect_ratio=False)
        # output[0] is the style bottleneck tensor
        return self.style_predict_executor.run([preprocessed_style_image])[0]

    def run_style_transfer(self, content_image):
        """
        Runs inference for given content_image and style bottleneck to create a stylized image.

        Args:
            content_image: a content image to stylize

        Returns:
            The stylized frame, resized back to content_image's dimensions and
            converted to BGR by style_transfer_postprocess.
        """
        # The content image has to be preprocessed to the transfer network's
        # input shape (e.g. (1, 384, 384, 3)).
        # NOTE(review): dtype is hard-coded to np.float32 rather than
        # self.style_transfer_executor.get_data_type() as on the predict
        # path — confirm this asymmetry is intended.
        preprocessed_content_image = cv_utils.preprocess(
            content_image, np.float32,
            self.style_transfer_executor.get_shape(), True, keep_aspect_ratio=False)

        # Transform content image. output[0] is the stylized image.
        stylized_image = self.style_transfer_executor.run(
            [preprocessed_content_image, self.style_bottleneck])[0]

        return style_transfer_postprocess(stylized_image, content_image.shape)