blob: fdf9d0ff162416f5eda67cb8c4b4dfd6830d3dce [file] [log] [blame]
Rickard Bolinbc6ee582022-11-04 08:24:29 +00001# SPDX-FileCopyrightText: Copyright 2020-2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
Louis Verhaard0b8268a2020-08-05 16:11:29 +02002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Rickard Bolinbc6ee582022-11-04 08:24:29 +000016#
Louis Verhaard0b8268a2020-08-05 16:11:29 +020017# Description:
18# Functionality for lookup table support.
19import uuid
Louis Verhaard0b8268a2020-08-05 16:11:29 +020020
Louis Verhaardb9fc33c2020-08-13 11:47:36 +020021import numpy as np
22
Louis Verhaard0b8268a2020-08-05 16:11:29 +020023from . import numeric_util
Dwight Lidman9b43f842020-12-08 17:56:44 +010024from .high_level_command_stream import DMA
25from .high_level_command_stream import NpuStripe
Louis Verhaardb9fc33c2020-08-13 11:47:36 +020026from .tensor import create_const_tensor
Louis Verhaard9db529a2020-09-23 10:27:11 +020027from .tensor import create_equivalence_id
Louis Verhaard0b8268a2020-08-05 16:11:29 +020028from .tensor import TensorPurpose
29
30
class LUTState:
    # Tracks which LUT-s are located in SHRAM.
    #
    # A state is never modified by put(); put() builds and returns a new state,
    # so an older state object can still be used after a later put().

    def __init__(self):
        # LUT tensors that are regarded as currently present in SHRAM
        self.tensors = []

    def get_equivalent(self, lut_tens):
        # Returns the first tracked LUT whose values are element-wise equal to
        # those of lut_tens, None if there is no such LUT
        return next((tens for tens in self.tensors if np.array_equal(tens.values, lut_tens.values)), None)

    def put(self, lut_tens):
        # Returns new LUT state containing given tensor + all tensors in this state
        # that do not overlap with the given tensor.
        # The given tensor is always the first entry of the new state.
        start = lut_tens.address
        end = start + lut_tens.storage_size()
        new_state = LUTState()
        new_state.tensors.append(lut_tens)
        for tens in self.tensors:
            tens_start = tens.address
            tens_end = tens_start + tens.storage_size()
            if not numeric_util.overlaps(start, end, tens_start, tens_end):
                new_state.tensors.append(tens)
        return new_state

    def find_best_address(self, start, stop, step):
        # Finds the address in range(start, stop, step) whose window (addr, addr + step),
        # as judged by numeric_util.overlaps, overlaps with the minimum number of
        # currently present LUT-s. On a tie the lowest candidate address wins.
        # Returns start if the range contains no candidate address.
        # An improvement would be to also take future LUT usage into account

        # The occupied spans do not depend on the candidate address; compute them once
        occupied = [(tens.address, tens.address + tens.storage_size()) for tens in self.tensors]

        def nr_overlaps(addr):
            return sum(1 for first, last in occupied if numeric_util.overlaps(addr, addr + step, first, last))

        # min() returns the first minimal candidate, and needs no artificial initial
        # "worst" count (an address is not a valid upper bound for an overlap count)
        return min(range(start, stop, step), key=nr_overlaps, default=start)
75
76
def get_lut_index(arch, lut_tensor):
    # Returns the index in SHRAM where the given LUT is stored: the tensor's offset
    # from arch.shram_lut_address, counted in units of the tensor's own storage size.
    # The index is asserted to be a value from 0 up to and including 7.
    offset = lut_tensor.address - arch.shram_lut_address
    index = offset // lut_tensor.storage_size()
    assert 0 <= index < 8
    return index
82
83
def create_lut_tensor(name, values, dtype):
    # Creates a constant tensor with purpose LUT that holds the given values as
    # lookup table; the table must have exactly 256 or 512 entries.
    # The equivalence_id of the tensor is derived from the values, so LUT tensors
    # created from identical values share it; they will then get the same address
    # in constant memory, and unnecessary DMA operations can be avoided.
    nr_entries = len(values)
    assert nr_entries in (256, 512)
    # 1-byte data types are stored as uint8, all other data types as uint32
    if dtype.size_in_bytes() == 1:
        value_dtype = np.uint8
    else:
        value_dtype = np.uint32
    lut_shape = [1, 1, 1, nr_entries]
    lut_tens = create_const_tensor(name, lut_shape, dtype, values, value_dtype, TensorPurpose.LUT)
    lut_tens.equivalence_id = create_equivalence_id(tuple(values))
    return lut_tens
95
96
def optimize_high_level_cmd_stream(sg, arch):
    # - Allocates SHRAM address/lut index to LUT tensors
    # - Removes unnecessary DMA operations of LUT-s that are already present in SHRAM from sg's command stream
    # sg.high_level_command_stream is replaced by the filtered command stream; nothing is returned.
    lut_slot_size = 256
    lut_region_start = arch.shram_lut_address
    lut_region_end = lut_region_start + arch.shram_lut_size
    resident_luts = LUTState()
    kept_cmds = []  # existing command stream minus the DMA operations that are not needed

    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.lut_tensor is None and arch.shram_reserved_unused_banks == 0:
            # The command overwrites the last 2 banks containing the LUT; next LUT operation will require DMA
            # TODO: check the command's SHRAM usage in more detail to determine if the LUT is overwritten or not
            resident_luts = LUTState()

        if not isinstance(cmd, DMA) or cmd.out_tensor.purpose != TensorPurpose.LUT:
            # Anything but a DMA of a LUT is passed through unchanged
            kept_cmds.append(cmd)
            continue

        # From here on cmd is a DMA operation that transfers a LUT
        lut_tens = cmd.out_tensor
        resident = resident_luts.get_equivalent(lut_tens)
        if resident is None:
            # Place the LUT in the last 2 blocks of SHRAM
            # Alignment is always on the size of the LUT, 256 for 256-byte LUT, 1K for 1K LUT, etc
            new_address = resident_luts.find_best_address(lut_region_start, lut_region_end, lut_tens.storage_size())
            lut_tens.equivalence_id = uuid.uuid4()
            lut_tens.address = new_address
            # NOTE(review): the index is counted in units of lut_slot_size here, whereas get_lut_index
            # divides by the tensor's storage size; the two presumably agree for the LUT sizes in use - confirm
            cmd.ps.primary_op.activation.lut_index = (new_address - lut_region_start) // lut_slot_size
            resident_luts = resident_luts.put(lut_tens)
            kept_cmds.append(cmd)
        else:
            # LUT is already in SHRAM: reuse its location and drop the DMA from the command stream
            lut_tens.equivalence_id = resident.equivalence_id
            lut_tens.address = resident.address
            cmd.ps.primary_op.activation.lut_index = get_lut_index(arch, resident)

    sg.high_level_command_stream = kept_cmds