| # SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com> |
| # |
| # SPDX-License-Identifier: Apache-2.0 |
| # |
| # Licensed under the Apache License, Version 2.0 (the License); you may |
| # not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an AS IS BASIS, WITHOUT |
| # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| # Description: |
| # Contains various utility functions used across the codebase. |
| from __future__ import annotations |
| |
| import collections |
| import inspect |
| |
| |
def progress_print(
    enabled: bool,
    message: str,
    progress_counter: int = -1,
    progress_total: int | collections.abc.Sized = 0,
    progress_granularity: float = 0.20,
):
    """Print progress information, prefixed with the name of the calling function.

    When a valid counter/total/granularity combination is supplied, the message is only printed every
    ``progress_granularity`` fraction of the total (and once more when the counter reaches the final value), followed
    by "counter/total". Otherwise the message is printed on every call without any counter information.

    :param enabled: boolean indicating whether message should be printed.
    :param message: message to be printed
    :param progress_counter: the value of the incremental counter that indicates the progress. Values below 0
                             disable the "counter/total" reporting
    :param progress_total: integer value or sized data structure to use to extract the total number of elements that
                           progress is measured against. For a sized data structure the displayed total is its
                           last index, i.e. ``len(progress_total) - 1``
    :param progress_granularity: floating point percentage, strictly between 0 and 1, indicating how often progress
                                 information should be printed

    Example
    -------
    def example_function(verbose_progress: bool = True):
        a_list = [x for x in range(101)]
        for index, value in enumerate(a_list):
            progress_print(verbose_progress,
                           message="Processing",
                           progress_counter=index,
                           progress_total=a_list,
                           progress_granularity=0.25)

    **Output**
    example_function: Processing 0/100
    example_function: Processing 25/100
    example_function: Processing 50/100
    example_function: Processing 75/100
    example_function: Processing 100/100
    """
    if not enabled:
        return

    display_total = progress_total
    # If a sized collection is provided, extract its size to use as progress total
    if isinstance(progress_total, collections.abc.Sized):
        progress_total = len(progress_total)
        # The displayed total is the last index of the collection (see the example in the docstring)
        display_total = progress_total - 1

    # Stays empty when the message must be printed without "counter/total" information
    progress_suffix = ""
    if progress_counter > -1 and progress_total > 0 and 0 < progress_granularity < 1:
        # Extract progress frequency and ensure it is not equal to 0 (avoid zero division)
        progress_frequency = int(progress_total * progress_granularity) or 1
        # A regular step is a multiple of the frequency that still leaves room for a full step before the end
        on_regular_step = (
            progress_counter % progress_frequency == 0 and progress_counter <= progress_total - progress_frequency
        )
        if not on_regular_step and progress_counter != display_total:
            # Nothing to report for this counter value; skip the caller lookup altogether
            return
        progress_suffix = f" {progress_counter}/{display_total}"

    # Get calling function name. Only done when something is printed, and via the direct parent frame rather than
    # inspect.stack(), which builds frame information for the entire call stack on every call.
    frame = inspect.currentframe()
    caller_frame = frame.f_back if frame is not None else None
    if caller_frame is not None:
        caller_name = caller_frame.f_code.co_name
    else:
        # Interpreter without direct frame access: fall back to the full stack walk
        caller_name = inspect.stack()[1].function
    # Drop frame references promptly to avoid keeping reference cycles alive
    del frame, caller_frame

    separator = ": " if message else ""
    print(f"{caller_name}{separator}{message}{progress_suffix}")
| |
| |
def calc_resize_factor(ifm_width: int, stride_x: int) -> tuple[int, int]:
    """Compute the resize factor and resulting stride for the strided Conv2D optimization.

    :param ifm_width: width of the input feature map
    :param stride_x: original horizontal stride
    :return: tuple of (resize factor, optimised stride), where the optimised stride is ``stride_x`` divided by the
             resize factor
    """
    # The whole stride can be folded away when it divides the IFM width
    if ifm_width % stride_x == 0:
        return stride_x, 1

    # Otherwise try to reduce the stride down to one of the strides supported by HW, in order of preference.
    # E.g.: IFM width = 133, stride = 14 can be optimised to IFM width = 19, stride = 2 using a resize factor of 7,
    # because 14 is a multiple of the supported stride 2 and 133 is divisible by 14 // 2 = 7.
    for hw_stride in (2, 3):
        if stride_x % hw_stride != 0:
            # This supported stride cannot be reached from the current stride
            continue
        candidate_factor = stride_x // hw_stride
        # The IFM width needs to be divisible by the resize factor
        if ifm_width % candidate_factor == 0:
            return candidate_factor, stride_x // candidate_factor

    # No suitable factor found: leave the stride untouched
    return 1, stride_x