blob: 7b563b087f5f99ecd46b4e00aa90e29e24197d6f [file] [log] [blame]
Johan Alfven55d90dd2024-04-02 16:32:54 +02001# SPDX-FileCopyrightText: Copyright 2020-2021, 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
Louis Verhaard0b8268a2020-08-05 16:11:29 +02002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Rickard Bolinbc6ee582022-11-04 08:24:29 +000016#
Louis Verhaard0b8268a2020-08-05 16:11:29 +020017# Description:
18# Functionality for lookup table support.
19import uuid
Louis Verhaard0b8268a2020-08-05 16:11:29 +020020
Louis Verhaardb9fc33c2020-08-13 11:47:36 +020021import numpy as np
22
Johan Alfven8e525ca2023-05-07 13:12:37 +020023from . import fp_math
Louis Verhaard0b8268a2020-08-05 16:11:29 +020024from . import numeric_util
Johan Alfvence502732023-04-24 13:35:40 +020025from .data_type import DataType
26from .debug_database import DebugDatabase
Dwight Lidman9b43f842020-12-08 17:56:44 +010027from .high_level_command_stream import DMA
28from .high_level_command_stream import NpuStripe
Johan Alfvence502732023-04-24 13:35:40 +020029from .numeric_util import round_away_zero
30from .operation import Op
Johan Alfven8e525ca2023-05-07 13:12:37 +020031from .scaling import quantise_scale
Louis Verhaardb9fc33c2020-08-13 11:47:36 +020032from .tensor import create_const_tensor
Louis Verhaard9db529a2020-09-23 10:27:11 +020033from .tensor import create_equivalence_id
Johan Alfvence502732023-04-24 13:35:40 +020034from .tensor import QuantizationParameters
Louis Verhaard0b8268a2020-08-05 16:11:29 +020035from .tensor import TensorPurpose
36
37
class LUTState:
    # Snapshot of the LUT tensors that are regarded as being present in SHRAM.
    def __init__(self):
        # Resident LUT tensors; put() stores the most recently placed tensor first
        self.tensors = []

    @staticmethod
    def _extent(tens):
        # Returns the pair (address, address + storage_size()) of the given tensor
        first = tens.address
        return first, first + tens.storage_size()

    def get_equivalent(self, lut_tens):
        # Searches the resident LUT-s for one whose values compare equal (np.array_equal)
        # to the values of lut_tens. Returns the first match, or None when there is none.
        return next(
            (resident for resident in self.tensors if np.array_equal(resident.values, lut_tens.values)),
            None,
        )

    def put(self, lut_tens):
        # Produces a new LUTState and leaves this one unmodified. The new state holds
        # lut_tens, followed by every tensor of this state whose address range does not
        # overlap the address range of lut_tens.
        result = LUTState()
        result.tensors.append(lut_tens)
        placed_start, placed_end = self._extent(lut_tens)
        for resident in self.tensors:
            if numeric_util.overlaps(placed_start, placed_end, *self._extent(resident)):
                # Overwritten by the new LUT, so no longer available
                continue
            result.tensors.append(resident)
        return result

    def find_best_address(self, start, stop, step):
        # Tries every address in range(start, stop, step) and returns the one whose window
        # (addr to addr + step) overlaps with the fewest resident LUT-s. The first candidate
        # wins a tie; start is returned when no candidate gets below the initial bound.
        # Future LUT usage is not considered, which would be a possible improvement.
        chosen_addr = start
        fewest_overlaps = stop  # initial bound, replaced by the first candidate that is below it
        for candidate in range(start, stop, step):
            overlap_count = sum(
                1
                for resident in self.tensors
                if numeric_util.overlaps(candidate, candidate + step, *self._extent(resident))
            )
            if overlap_count < fewest_overlaps:
                fewest_overlaps = overlap_count
                chosen_addr = candidate
        return chosen_addr
82
83
def get_lut_index(arch, lut_tensor):
    # Converts the SHRAM address of the given LUT into its slot index: the offset of
    # lut_tensor.address from arch.shram_lut_address, counted in whole slots of
    # arch.shram_lut_slot_size. The index must lie in the range 0..7 (asserted).
    offset = lut_tensor.address - arch.shram_lut_address
    index = offset // arch.shram_lut_slot_size
    assert 0 <= index < 8
    return index
89
90
def create_lut_tensor(name, values, dtype):
    # Builds a constant tensor of shape [1, 1, 1, len(values)] and purpose TensorPurpose.LUT
    # that holds the given lookup table. Only tables with 256 or 512 entries are accepted.
    # The equivalence_id is derived from the table contents, so LUT tensors created from
    # identical values end up with the same address in constant memory, which makes it
    # possible to avoid unnecessary DMA operations.
    nr_entries = len(values)
    assert nr_entries in (256, 512)
    lut_dtype = dtype
    if dtype == DataType.int16:
        # An int16 lut is stored as a uint32 lut with base + slope
        lut_dtype = DataType.uint32
    lut_tens = create_const_tensor(name, [1, 1, 1, nr_entries], lut_dtype, values, TensorPurpose.LUT)
    lut_tens.equivalence_id = create_equivalence_id(tuple(values))
    return lut_tens
103
104
def optimize_high_level_cmd_stream(sg, arch):
    # Post-processes sg.high_level_command_stream with respect to LUT handling:
    # - LUT tensors that are the target of a DMA get a SHRAM address, and the activation of
    #   the DMA's primary op gets the matching lut index
    # - DMA operations of LUT-s whose values are already present in SHRAM are dropped
    # The filtered command list is written back to sg.high_level_command_stream.
    lut_region_start = arch.shram_lut_address
    lut_region_end = lut_region_start + arch.shram_lut_size
    resident_luts = LUTState()
    kept_cmds = []  # the existing commands, without the DMA operations that are not needed
    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.lut_tensor is None and arch.shram_reserved_unused_banks == 0:
            # The stripe overwrites the last 2 banks containing the LUT, so nothing is
            # resident any longer and the next LUT operation will require DMA
            # TODO: check the command's SHRAM usage in more detail to determine if the LUT is overwritten or not
            resident_luts = LUTState()

        is_lut_dma = isinstance(cmd, DMA) and cmd.out_tensor.purpose == TensorPurpose.LUT
        if not is_lut_dma:
            # Anything but a LUT DMA passes through unchanged
            kept_cmds.append(cmd)
            continue

        lut_tens = cmd.out_tensor
        resident_match = resident_luts.get_equivalent(lut_tens)
        if resident_match is None:
            # Not in SHRAM yet: place the LUT in the last 2 blocks of SHRAM and keep the DMA.
            # Alignment is always on the size of the LUT, 256 for 256-byte LUT, 1K for 1K LUT, etc
            address = resident_luts.find_best_address(lut_region_start, lut_region_end, lut_tens.storage_size())

            lut_tens.equivalence_id = uuid.uuid4()
            lut_tens.address = address
            cmd.ps.primary_op.activation.lut_index = (address - lut_region_start) // arch.shram_lut_slot_size
            resident_luts = resident_luts.put(lut_tens)
            kept_cmds.append(cmd)
        else:
            # Same values are already in SHRAM: reuse that LUT and drop the DMA
            lut_tens.equivalence_id = resident_match.equivalence_id
            lut_tens.address = resident_match.address
            cmd.ps.primary_op.activation.lut_index = get_lut_index(arch, resident_match)
    sg.high_level_command_stream = kept_cmds
Johan Alfvence502732023-04-24 13:35:40 +0200140
141
def convert_to_lut(op, lut_values, lut_name):
    # Rewrites op, in place, into an Add of its IFM with a scalar 0, followed by a LUT
    # activation that holds lut_values. lut_name becomes part of the new op name.
    # Returns op; an op without IFM is returned unchanged.
    ifm, ofm = op.ifm, op.ofm
    if ifm is None:
        return op
    assert ifm.dtype in (DataType.int8, DataType.uint8, DataType.int16)

    op.type = Op.Add
    op.name = f"{op.name}_lut_{lut_name}"
    # Flagged as no-op so that potential fusing optimizations can kick in
    op.attrs["is_nop"] = True

    # Second input: a constant scalar zero that uses the IFM scale and zero point 0
    if ifm.dtype == DataType.int16:
        quant_max = 65536.0
    else:
        quant_max = 255.0
    zero_quant = QuantizationParameters(0.0, quant_max)
    zero_quant.scale_f32 = ifm.quantization.scale_f32
    zero_quant.zero_point = 0
    zero_tens = create_const_tensor(ifm.name + "_scalar0", [], ifm.dtype, [0], quantization=zero_quant)
    op.add_input_tensor(zero_tens)

    # No rescaling may happen before the LUT is applied (the LUT itself performs the rescale).
    # Hence the generated OFM scale instructions must be the same as for the IFM, also when
    # the OFM has a different scale than the IFM
    op.forced_output_quantization = ifm.quantization

    # The datatype of the lut tensor has to match the ofm datatype, since the lut values are
    # what is output. It also has to match the datatype the lut values were generated with
    # (probably the ifm datatype), to avoid potential overflow errors in create_lut_tensor()
    # when a Python int (which could represent a uint) is converted to a NumPy int.
    # Requiring identical ifm and ofm datatypes guarantees both
    assert ifm.dtype == ofm.dtype
    lut_tensor = create_lut_tensor(op.name + "_values", lut_values, ofm.dtype)
    op.set_activation_lut(lut_tensor)
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)
    return op
176
177
def create_lut_8bit_op(op, lut_fn, fn_name):
    # Tabulates lut_fn for every quantized 8-bit input value and rewrites op into its
    # LUT based form through convert_to_lut(), whose result is returned.
    # The table covers 0..255 for a uint8 IFM and -128..127 otherwise.
    ifm_scale = op.ifm.quantization.scale_f32
    ofm_scale = op.ofm.quantization.scale_f32
    zp_in = op.ifm.quantization.zero_point
    zp_out = op.ofm.quantization.zero_point

    if op.ifm.dtype == DataType.uint8:
        quantized_inputs = range(256)
    else:
        quantized_inputs = range(-128, 128)
    quantized_min = min(quantized_inputs)
    quantized_max = max(quantized_inputs)

    def table_entry(x):
        # Dequantize with the IFM parameters, apply the function, requantize with the
        # OFM parameters and clamp to the quantized range
        real_in = ifm_scale * (x - zp_in)
        real_out = lut_fn(real_in)
        requantized = round_away_zero(real_out / ofm_scale) + zp_out
        return min(quantized_max, max(quantized_min, requantized))

    values = [table_entry(x) for x in quantized_inputs]
    return convert_to_lut(op, values, fn_name)
196
197
def create_lut_int16_op(op, lut_fn, fn_name):
    # Tabulates lut_fn for an int16 operation and rewrites op into its LUT based form
    # through convert_to_lut(), whose result is returned.
    # 513 sample values over the int16 input range are computed first; these are then
    # packed into 512 entries that combine each value with the difference to the next one.
    int16_limits = np.iinfo(np.int16)
    table_min = int16_limits.min
    table_max = int16_limits.max

    ifm_scale = op.ifm.quantization.scale_f32
    ofm_scale = op.ofm.quantization.scale_f32
    zp_in = op.ifm.quantization.zero_point
    zp_out = op.ofm.quantization.zero_point

    # Real value ranges that correspond to the full int16 range of the IFM and the OFM
    input_min = ifm_scale * (table_min - zp_in)
    input_max = ifm_scale * (table_max - zp_in)
    output_min = ofm_scale * (table_min - zp_out)
    output_max = ofm_scale * (table_max - zp_out)

    # Create 16bit lut following the reference
    nbr_steps = 512
    step = (input_max - input_min) / nbr_steps
    half_step = step / 2
    output_scaling_inv = (table_max - table_min + 1) / (output_max - output_min)

    def clamp_to_table(value):
        return min(max(value, table_min), table_max)

    samples = []
    for i in range(nbr_steps):
        interval_start = input_min + i * step
        at_start = lut_fn(interval_start)
        at_midpoint = lut_fn(interval_start + half_step)
        at_next = lut_fn(input_min + (i + 1) * step)

        sample_val = round_away_zero(at_start * output_scaling_inv)
        midpoint_interp_val = round_away_zero((at_next * output_scaling_inv + sample_val) / 2)
        midpoint_val = round_away_zero(at_midpoint * output_scaling_inv)
        # Half of the error that interpolation makes at the interval midpoint is
        # subtracted from the sample value
        bias = round_away_zero((midpoint_interp_val - midpoint_val) / 2)
        samples.append(clamp_to_table(sample_val - bias))

    # Closing sample at the upper end of the input range, without midpoint correction
    samples.append(clamp_to_table(round_away_zero(lut_fn(input_max) * output_scaling_inv)))

    # Convert to hardware 16bit lut with base and slope: the slope (difference to the
    # next sample) is shifted left by 16 and the base (the sample itself) is added
    lut = [((int(following) - int(current)) << 16) + int(current) for current, following in zip(samples, samples[1:])]

    return convert_to_lut(op, lut, fn_name)
Johan Alfven8e525ca2023-05-07 13:12:37 +0200247
248
def create_lut_rsqrt_int8_op(op):
    # Builds the 256 entry int8 LUT for Rsqrt from a fixed reference table and rewrites op
    # into its LUT based form through convert_to_lut(op, values, "rsqrt"), whose result
    # is returned.

    # Turn off black formatting for the LUT tables to keep them compact
    # fmt: off

    # RSQRT_LUT has been generated by printing the output from the reference.
    # These values are always the same but for some unknown reason it is not being
    # implemented as a LUT in the reference.
    # So based on the input range (-128, 127) the reference produces the following output:
    RSQRT_LUT = [
        0x00000000, 0x00100000, 0x000b504e, 0x00093cd4, 0x00080000, 0x000727c9, 0x0006882f, 0x00060c24,
        0x0005a827, 0x00055555, 0x00050f45, 0x0004d2fe, 0x00049e6a, 0x00047007, 0x000446b4, 0x00042195,
        0x00040000, 0x0003e16d, 0x0003c570, 0x0003abb0, 0x000393e5, 0x00037dd2, 0x00036945, 0x00035613,
        0x00034418, 0x00033333, 0x0003234b, 0x00031447, 0x00030612, 0x0002f89c, 0x0002ebd3, 0x0002dfaa,
        0x0002d414, 0x0002c906, 0x0002be75, 0x0002b45a, 0x0002aaab, 0x0002a161, 0x00029875, 0x00028fe3,
        0x000287a2, 0x00027fb0, 0x00027807, 0x000270a2, 0x0002697f, 0x00026298, 0x00025bec, 0x00025577,
        0x00024f35, 0x00024925, 0x00024343, 0x00023d8e, 0x00023803, 0x000232a1, 0x00022d65, 0x0002284e,
        0x0002235a, 0x00021e87, 0x000219d5, 0x00021541, 0x000210cb, 0x00020c70, 0x00020831, 0x0002040c,
        0x00020000, 0x0001fc0c, 0x0001f82f, 0x0001f468, 0x0001f0b7, 0x0001ed1a, 0x0001e991, 0x0001e61b,
        0x0001e2b8, 0x0001df67, 0x0001dc26, 0x0001d8f7, 0x0001d5d8, 0x0001d2c8, 0x0001cfc8, 0x0001ccd6,
        0x0001c9f2, 0x0001c71c, 0x0001c454, 0x0001c198, 0x0001bee9, 0x0001bc46, 0x0001b9af, 0x0001b723,
        0x0001b4a3, 0x0001b22d, 0x0001afc2, 0x0001ad61, 0x0001ab0a, 0x0001a8bc, 0x0001a678, 0x0001a43e,
        0x0001a20c, 0x00019fe3, 0x00019dc2, 0x00019baa, 0x0001999a, 0x00019791, 0x00019590, 0x00019397,
        0x000191a5, 0x00018fbb, 0x00018dd7, 0x00018bfa, 0x00018a23, 0x00018853, 0x0001868a, 0x000184c6,
        0x00018309, 0x00018152, 0x00017fa0, 0x00017df4, 0x00017c4e, 0x00017aad, 0x00017911, 0x0001777b,
        0x000175e9, 0x0001745d, 0x000172d6, 0x00017153, 0x00016fd5, 0x00016e5b, 0x00016ce7, 0x00016b76,
        0x00016a0a, 0x000168a2, 0x0001673e, 0x000165de, 0x00016483, 0x0001632b, 0x000161d7, 0x00016087,
        0x00015f3b, 0x00015df2, 0x00015cad, 0x00015b6b, 0x00015a2d, 0x000158f2, 0x000157bb, 0x00015686,
        0x00015555, 0x00015427, 0x000152fd, 0x000151d5, 0x000150b0, 0x00014f8f, 0x00014e70, 0x00014d54,
        0x00014c3b, 0x00014b24, 0x00014a11, 0x00014900, 0x000147f1, 0x000146e5, 0x000145dc, 0x000144d5,
        0x000143d1, 0x000142cf, 0x000141d0, 0x000140d3, 0x00013fd8, 0x00013ee0, 0x00013de9, 0x00013cf5,
        0x00013c03, 0x00013b14, 0x00013a26, 0x0001393b, 0x00013851, 0x0001376a, 0x00013684, 0x000135a1,
        0x000134bf, 0x000133e0, 0x00013302, 0x00013226, 0x0001314c, 0x00013074, 0x00012f9e, 0x00012ec9,
        0x00012df6, 0x00012d25, 0x00012c55, 0x00012b87, 0x00012abb, 0x000129f1, 0x00012928, 0x00012860,
        0x0001279a, 0x000126d6, 0x00012613, 0x00012552, 0x00012492, 0x000123d4, 0x00012317, 0x0001225c,
        0x000121a2, 0x000120e9, 0x00012032, 0x00011f7c, 0x00011ec7, 0x00011e14, 0x00011d62, 0x00011cb1,
        0x00011c02, 0x00011b54, 0x00011aa7, 0x000119fb, 0x00011950, 0x000118a7, 0x000117ff, 0x00011758,
        0x000116b3, 0x0001160e, 0x0001156b, 0x000114c8, 0x00011427, 0x00011387, 0x000112e8, 0x0001124a,
        0x000111ad, 0x00011111, 0x00011076, 0x00010fdc, 0x00010f44, 0x00010eac, 0x00010e15, 0x00010d7f,
        0x00010cea, 0x00010c56, 0x00010bc4, 0x00010b32, 0x00010aa0, 0x00010a10, 0x00010981, 0x000108f3,
        0x00010865, 0x000107d9, 0x0001074d, 0x000106c2, 0x00010638, 0x000105af, 0x00010527, 0x0001049f,
        0x00010419, 0x00010393, 0x0001030e, 0x0001028a, 0x00010206, 0x00010183, 0x00010102, 0x00010080
    ]

    # Transform the above LUT so it gets the correct quantization (following the reference)
    ifm_scale = op.ifm.quantization.scale_f32
    ofm_scale = op.ofm.quantization.scale_f32
    zp_in = op.ifm.quantization.zero_point
    zp_out = op.ofm.quantization.zero_point

    scale = np.double(1) / np.double(np.sqrt(ifm_scale) * ofm_scale)
    output_multiplier, output_shift = quantise_scale(scale)

    # Shift modification (value used in reference but Vela has opposite sign)
    kshift = -20
    total_shift = output_shift - kshift

    int8_inputs = range(-128, 128)
    quantized_min = int8_inputs[0]
    quantized_max = int8_inputs[-1]

    # Any value close to 0 (zero index in LUT) is mapped to the max output value,
    # so the entry for the first input (-128) is set up front
    values = [quantized_max]
    for x in int8_inputs[1:]:
        # Rsqrt is only defined for positive values, so the table index is never below 0
        table_index = max(0, x - zp_in)
        rescaled = (
            fp_math.multiply_by_quantized_multiplier(RSQRT_LUT[table_index], output_multiplier, total_shift) + zp_out
        )
        values.append(min(quantized_max, max(quantized_min, rescaled)))

    return convert_to_lut(op, values, "rsqrt")