Raul Farkas | 1c54ac1 | 2023-04-26 07:49:15 +0100 | [diff] [blame] | 1 | # SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com> |
| 2 | # |
| 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | # |
| 5 | # Licensed under the Apache License, Version 2.0 (the License); you may |
| 6 | # not use this file except in compliance with the License. |
| 7 | # You may obtain a copy of the License at |
| 8 | # |
| 9 | # www.apache.org/licenses/LICENSE-2.0 |
| 10 | # |
| 11 | # Unless required by applicable law or agreed to in writing, software |
| 12 | # distributed under the License is distributed on an AS IS BASIS, WITHOUT |
| 13 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | # See the License for the specific language governing permissions and |
| 15 | # limitations under the License. |
| 16 | # |
| 17 | # Description: |
| 18 | # Contains various utility functions used across the codebase. |
| 19 | from __future__ import annotations |
| 20 | |
| 21 | import collections |
| 22 | import inspect |
| 23 | |
| 24 | |
def progress_print(
    enabled: bool,
    message: str,
    progress_counter: int = -1,
    progress_total: int | collections.abc.Sized = 0,
    progress_granularity: float = 0.20,
):
    """Print a message prefixed with the name of the calling function.

    When a valid counter, total and granularity are supplied, the message is only printed at
    regular intervals and is followed by "counter/total"; otherwise the message is always printed
    on its own.

    :param enabled: boolean indicating whether message should be printed.
    :param message: message to be printed
    :param progress_counter: the value of the incremental counter that indicates the progress
    :param progress_total: integer value or sized data structure to use to extract the total number of elements that
                           progress is measured against. For a sized data structure the displayed total is
                           len(progress_total) - 1, i.e. the last index.
    :param progress_granularity: floating point percentage (exclusive range 0..1) indicating how often progress
                                 information should be printed

    Example
    -------
    def example_function(verbose_progress: bool = True):
        a_list = [x for x in range(101)]
        for index, value in enumerate(a_list):
            progress_print(verbose_progress,
                           message="Processing",
                           progress_counter=index,
                           progress_total=a_list,
                           progress_granularity=0.25)

    **Output**
    example_function: Processing 0/100
    example_function: Processing 25/100
    example_function: Processing 50/100
    example_function: Processing 75/100
    example_function: Processing 100/100
    """
    if not enabled:
        return

    # Get calling function name. context=0 avoids reading source lines for every frame on the
    # stack, which matters because this function is typically called once per loop iteration.
    context_str = inspect.stack(0)[1].function
    context_str += ": " if message else ""
    display_total = progress_total
    # If a sized collection is provided, extract its size to use as progress total
    if isinstance(progress_total, collections.abc.Sized):
        progress_total = len(progress_total)
        # The counter is expected to be an index into the collection, so the last value is len - 1
        display_total = progress_total - 1

    # Print progress information with "counter/total" information
    if progress_counter > -1 and progress_total > 0 and 0 < progress_granularity < 1:
        # Extract progress frequency and ensure it is not equal to 0 (avoid zero division)
        progress_frequency = int(progress_total * progress_granularity) or 1
        # Check whether information should be printed based on computed progress frequency.
        # The upper bound suppresses a regular tick that would land too close to the final one;
        # the final value itself is always printed.
        if (
            progress_counter % progress_frequency == 0 and progress_counter <= progress_total - progress_frequency
        ) or progress_counter == display_total:
            print(f"{context_str}{message} {progress_counter}/{display_total}")
        # In progress mode nothing is printed for intermediate counter values
        return

    print(f"{context_str}{message}")
Raul Farkas | 3b64f06 | 2023-05-16 17:18:31 +0100 | [diff] [blame] | 87 | |
| 88 | |
def calc_resize_factor(ifm_width: int, stride_x: int) -> tuple[int, int]:
    """Compute resize factor for strided Conv2D optimization.

    :param ifm_width: width of the IFM
    :param stride_x: stride in the x direction
    :return: tuple of (resize factor, optimised stride), where the optimised stride is
             stride_x divided by the resize factor
    """
    # Strides that are supported by HW, tried in this order
    supported_hw_strides = (2, 3)

    # When the IFM width is divisible by the stride, the stride itself is the resize factor
    factor = stride_x

    if ifm_width % stride_x:
        # The IFM width is not divisible by the stride. Look for a HW supported stride that
        # divides the original stride and whose matching resize factor (stride / HW stride)
        # divides the IFM width.
        # E.g.: IFM width = 133, stride = 14, filter width = 7 can be
        # optimised to IFM width = 19, stride = 2, filter width = 7 using
        # a resize factor of 7. The final stride is 2 which is
        # supported by the hardware.
        factor = 1  # used when no suitable HW stride is found, leaving the stride unchanged
        for hw_stride in supported_hw_strides:
            if stride_x % hw_stride:
                # This HW stride cannot be obtained from the current stride
                continue
            reduced = stride_x // hw_stride
            if ifm_width % reduced == 0:
                # First matching HW stride wins
                factor = reduced
                break

    return factor, stride_x // factor