Tim Hall | 3b1578e | 2023-01-13 17:57:25 +0000 | [diff] [blame] | 1 | # SPDX-FileCopyrightText: Copyright 2020-2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com> |
Louis Verhaard | 0b8268a | 2020-08-05 16:11:29 +0200 | [diff] [blame] | 2 | # |
| 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | # |
| 5 | # Licensed under the Apache License, Version 2.0 (the License); you may |
| 6 | # not use this file except in compliance with the License. |
| 7 | # You may obtain a copy of the License at |
| 8 | # |
| 9 | # www.apache.org/licenses/LICENSE-2.0 |
| 10 | # |
| 11 | # Unless required by applicable law or agreed to in writing, software |
| 12 | # distributed under the License is distributed on an AS IS BASIS, WITHOUT |
| 13 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | # See the License for the specific language governing permissions and |
| 15 | # limitations under the License. |
Rickard Bolin | bc6ee58 | 2022-11-04 08:24:29 +0000 | [diff] [blame] | 16 | # |
Louis Verhaard | 0b8268a | 2020-08-05 16:11:29 +0200 | [diff] [blame] | 17 | # Description: |
| 18 | # Functionality for lookup table support. |
| 19 | import uuid |
Louis Verhaard | 0b8268a | 2020-08-05 16:11:29 +0200 | [diff] [blame] | 20 | |
Louis Verhaard | b9fc33c | 2020-08-13 11:47:36 +0200 | [diff] [blame] | 21 | import numpy as np |
| 22 | |
Louis Verhaard | 0b8268a | 2020-08-05 16:11:29 +0200 | [diff] [blame] | 23 | from . import numeric_util |
Dwight Lidman | 9b43f84 | 2020-12-08 17:56:44 +0100 | [diff] [blame] | 24 | from .high_level_command_stream import DMA |
| 25 | from .high_level_command_stream import NpuStripe |
Louis Verhaard | b9fc33c | 2020-08-13 11:47:36 +0200 | [diff] [blame] | 26 | from .tensor import create_const_tensor |
Louis Verhaard | 9db529a | 2020-09-23 10:27:11 +0200 | [diff] [blame] | 27 | from .tensor import create_equivalence_id |
Louis Verhaard | 0b8268a | 2020-08-05 16:11:29 +0200 | [diff] [blame] | 28 | from .tensor import TensorPurpose |
| 29 | |
| 30 | |
class LUTState:
    # Snapshot of the LUT tensors that are considered to be located in SHRAM.
    # put() does not modify an existing state; it hands back a fresh one.
    def __init__(self):
        # LUT tensors in this state; the methods below read .values, .address
        # and .storage_size() from them
        self.tensors = []

    @staticmethod
    def _extent(tens):
        # Returns (start, end) of the address range taken up by the given tensor
        first = tens.address
        return first, first + tens.storage_size()

    def get_equivalent(self, lut_tens):
        # Returns the first LUT in this state whose values compare equal
        # (np.array_equal) to those of lut_tens, or None if there is no such LUT
        matches = (resident for resident in self.tensors if np.array_equal(resident.values, lut_tens.values))
        return next(matches, None)

    def put(self, lut_tens):
        # Returns a new LUTState that holds lut_tens first, followed by every tensor
        # of this state whose address range does not overlap the range of lut_tens.
        # This state itself is left as it was.
        successor = LUTState()
        successor.tensors.append(lut_tens)
        lo, hi = self._extent(lut_tens)
        for resident in self.tensors:
            other_lo, other_hi = self._extent(resident)
            if numeric_util.overlaps(lo, hi, other_lo, other_hi):
                # Overwritten by the new LUT; do not carry it over
                continue
            successor.tensors.append(resident)
        return successor

    def find_best_address(self, start, stop, step):
        # Walks the candidate addresses range(start, stop, step) and returns the first
        # candidate whose window (candidate to candidate + step) overlaps the fewest
        # LUT-s of this state.
        # The running minimum is seeded with stop, so a candidate is only accepted when
        # its overlap count is lower than stop; start is returned when none is accepted.
        # Future LUT usage is not taken into account.
        chosen = start
        fewest = stop
        for candidate in range(start, stop, step):
            window_end = candidate + step
            clashes = sum(
                1
                for resident in self.tensors
                if numeric_util.overlaps(candidate, window_end, *self._extent(resident))
            )
            if clashes < fewest:
                chosen, fewest = candidate, clashes
        return chosen
| 75 | |
| 76 | |
def get_lut_index(arch, lut_tensor):
    # Returns the slot number of the given LUT inside the SHRAM LUT area: the
    # tensor's offset from arch.shram_lut_address, counted in units of the
    # tensor's own storage size. The slot is asserted to satisfy 0 <= slot < 8.
    offset = lut_tensor.address - arch.shram_lut_address
    lut_index = offset // lut_tensor.storage_size()
    assert 0 <= lut_index < 8
    return lut_index
| 82 | |
| 83 | |
def create_lut_tensor(name, values, dtype):
    # Builds a constant tensor with purpose LUT and shape [1, 1, 1, len(values)]
    # that uses the given values as lookup table. Only tables with 256 or 512
    # entries are accepted (asserted).
    # The equivalence_id is derived from the values, so LUT tensors created from
    # identical values share the same id; the intent is that they end up at the
    # same address in constant memory so that unnecessary DMA operations can be
    # avoided.
    entry_count = len(values)
    assert entry_count in (256, 512)
    lut_shape = [1, 1, 1, entry_count]
    lut = create_const_tensor(name, lut_shape, dtype, values, TensorPurpose.LUT)
    lut.equivalence_id = create_equivalence_id(tuple(values))
    return lut
| 94 | |
| 95 | |
def optimize_high_level_cmd_stream(sg, arch):
    # Replaces sg.high_level_command_stream with a filtered copy of itself:
    # - LUT tensors that are DMA-ed get an SHRAM address, and the activation of the
    #   primary op of the command's pass gets the matching lut index
    # - DMA commands for a LUT whose values are already tracked as present in SHRAM
    #   are left out of the new command stream
    slot_bytes = 256
    lut_area_start = arch.shram_lut_address
    lut_area_end = lut_area_start + arch.shram_lut_size
    resident_luts = LUTState()
    kept_cmds = []  # the existing command stream minus the unneeded DMA operations

    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe):
            if cmd.ps.lut_tensor is None and arch.shram_reserved_unused_banks == 0:
                # The command overwrites the last 2 banks containing the LUT; next LUT operation will require DMA
                # TODO: check the command's SHRAM usage in more detail to determine if the LUT is overwritten or not
                resident_luts = LUTState()

        if not isinstance(cmd, DMA):
            # Not a DMA at all; passes through unchanged
            kept_cmds.append(cmd)
            continue
        lut_tens = cmd.out_tensor
        if lut_tens.purpose != TensorPurpose.LUT:
            # DMA of something other than a LUT; passes through unchanged
            kept_cmds.append(cmd)
            continue

        # From here on the command is a LUT DMA
        resident = resident_luts.get_equivalent(lut_tens)
        if resident is None:
            # No LUT with these values is tracked yet: pick an address in the LUT area.
            # Candidate addresses are stepped by the LUT's own storage size, so the
            # chosen address is aligned on that size relative to the area start.
            address = resident_luts.find_best_address(lut_area_start, lut_area_end, lut_tens.storage_size())
            lut_tens.equivalence_id = uuid.uuid4()
            lut_tens.address = address
            cmd.ps.primary_op.activation.lut_index = (address - lut_area_start) // slot_bytes
            resident_luts = resident_luts.put(lut_tens)
            kept_cmds.append(cmd)
        else:
            # An equal LUT is already tracked: reuse its id, address and slot, and
            # drop this DMA from the command stream
            lut_tens.equivalence_id = resident.equivalence_id
            lut_tens.address = resident.address
            cmd.ps.primary_op.activation.lut_index = get_lut_index(arch, resident)

    sg.high_level_command_stream = kept_cmds