# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Functionality for lookup table support.
import uuid
from functools import lru_cache

from . import numeric_util
from .high_level_command_stream import CommandType
from .tensor import TensorPurpose


@lru_cache(maxsize=None)
def create_equivalence_id(key):
    # Generates an equivalence_id based on the given key.
    # The DMA optimization of LUTs assumes that 2 LUT tensors are identical
    # if they have the same equivalence_id.
    # So, for example, all created 256-byte tanh LUT tensors should have
    # the same equivalence id.
    # The lru_cache ensures that repeated calls with an equal key return
    # the same id, while a new key produces a fresh uuid4.
    return uuid.uuid4()


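# Illustrative sketch (not part of the original module): demonstrates the caching
# behaviour of create_equivalence_id described above. The key strings are made up
# for the example; any hashable key works.
def _example_equivalence_id_caching():
    id_a = create_equivalence_id("tanh-256")
    id_b = create_equivalence_id("tanh-256")
    id_c = create_equivalence_id("sigmoid-256")
    assert id_a == id_b  # same key -> same equivalence id
    assert id_a != id_c  # different key -> different equivalence id

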
class LUTState:
    # Tracks which LUTs are located in SHRAM.
    def __init__(self):
        self.tensors = []

    def get_equivalent(self, lut_tens):
        # Returns an existing LUT with the same equivalence id, None if not found
        for t in self.tensors:
            if t.equivalent(lut_tens):
                return t
        return None

    def put(self, lut_tens):
        # Returns a new LUT state containing the given tensor + all tensors in this state
        # that do not overlap with the given tensor
        new_state = LUTState()
        new_state.tensors.append(lut_tens)
        start = lut_tens.address
        end = start + lut_tens.storage_size()
        for tens in self.tensors:
            start2 = tens.address
            end2 = start2 + tens.storage_size()
            if not numeric_util.overlaps(start, end, start2, end2):
                new_state.tensors.append(tens)
        return new_state

    def find_best_address(self, start, stop, step):
        # Finds the address in the given range that overlaps with the minimum number of
        # currently present LUTs.
        # An improvement would be to also take future LUT usage into account
        best_addr = start
        best_nr_overlaps = stop  # upper bound; any real overlap count is smaller
        for addr in range(start, stop, step):
            nr_overlaps = 0
            for tens in self.tensors:
                start2 = tens.address
                end2 = start2 + tens.storage_size()
                if numeric_util.overlaps(addr, addr + step, start2, end2):
                    nr_overlaps += 1
            if nr_overlaps < best_nr_overlaps:
                best_nr_overlaps = nr_overlaps
                best_addr = addr
        return best_addr


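# Illustrative sketch (not part of the original module): a minimal stand-in for a
# LUT tensor, showing how LUTState tracks SHRAM occupancy. _StubLut, its fields and
# the addresses below are invented for the example; real callers pass vela Tensor
# objects, which expose the same address/storage_size/equivalent interface.
class _StubLut:
    def __init__(self, address, size, eq_id):
        self.address = address
        self.size = size
        self.equivalence_id = eq_id
        self.purpose = TensorPurpose.LUT

    def storage_size(self):
        return self.size

    def equivalent(self, other):
        return self.equivalence_id == other.equivalence_id


def _example_lut_state():
    state = LUTState()
    tanh_lut = _StubLut(address=0x400, size=256, eq_id=create_equivalence_id("tanh-256"))
    state = state.put(tanh_lut)
    # A LUT with the same equivalence id is found, so its DMA could be elided
    again = _StubLut(address=None, size=256, eq_id=create_equivalence_id("tanh-256"))
    assert state.get_equivalent(again) is tanh_lut
    # A LUT written to an overlapping address range evicts the old entry
    state = state.put(_StubLut(address=0x400, size=256, eq_id=create_equivalence_id("sigmoid-256")))
    assert state.get_equivalent(again) is None

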
def get_lut_index(arch, lut_tensor):
    # Returns the index in SHRAM where the given LUT is stored, a value between 0 and 7
    slot = (lut_tensor.address - arch.shram_lut_address) // lut_tensor.storage_size()
    assert 0 <= slot < 8
    return slot


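# Illustrative sketch (not part of the original module), reusing _StubLut from the
# example above: with a made-up shram_lut_address of 0x400, a 256-byte LUT placed
# at address 0x500 lands in slot (0x500 - 0x400) // 256 == 1.
def _example_get_lut_index():
    class _StubArch:
        shram_lut_address = 0x400

    lut = _StubLut(address=0x500, size=256, eq_id=None)
    assert get_lut_index(_StubArch(), lut) == 1

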
def optimize_high_level_cmd_stream(sg, arch):
    # - Allocates SHRAM address/lut index to LUT tensors
    # - Removes unnecessary DMA operations of LUTs that are already present in SHRAM from sg's command stream
    cmd_stream = []  # will contain the existing command stream minus unneeded DMA operations
    lut_state = LUTState()
    slot_size = 256
    lut_start = arch.shram_lut_address
    lut_end = lut_start + arch.shram_lut_size
    for cmd in sg.high_level_command_stream:
        if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.lut_tensor is None and arch.shram_reserved_unused_banks == 0:
            # The command overwrites the last 2 banks containing the LUT; the next LUT operation will require DMA
            # TODO: check the command's SHRAM usage in more detail to determine if the LUT is overwritten or not
            lut_state = LUTState()
        if cmd.cmdtype != CommandType.DMA or cmd.out_tensor.purpose != TensorPurpose.LUT:
            # Non-LUT operation; leave untouched
            cmd_stream.append(cmd)
            continue
        # LUT DMA operation
        lut_tens = cmd.out_tensor
        existing_tens = lut_state.get_equivalent(lut_tens)
        if existing_tens is not None:
            # LUT is already in SHRAM, no need to perform DMA
            lut_tens.address = existing_tens.address
            cmd.ps.primary_op.attrs["lut_index"] = get_lut_index(arch, existing_tens)
            continue
        # Place the LUT in the last 2 blocks of SHRAM
        # Alignment is always on the size of the LUT: 256 for a 256-byte LUT, 1K for a 1K LUT, etc.
        address = lut_state.find_best_address(lut_start, lut_end, lut_tens.storage_size())
        lut_tens.address = address
        cmd.ps.primary_op.attrs["lut_index"] = (address - lut_start) // slot_size
        lut_state = lut_state.put(lut_tens)
        cmd_stream.append(cmd)
    sg.high_level_command_stream = cmd_stream
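

# Illustrative sketch (not part of the original module): two DMA commands that load
# LUTs with the same equivalence id; the optimizer keeps only the first DMA and
# rewrites both ops' lut_index to the same slot. Every _Stub* class and field value
# here is invented for the example (reusing _StubLut from above); real command
# streams use vela's Command/Pass objects.
def _example_optimize_high_level_cmd_stream():
    class _StubArch:
        shram_lut_address = 0x400
        shram_lut_size = 2048
        shram_reserved_unused_banks = 2

    class _StubOp:
        def __init__(self):
            self.attrs = {}

    class _StubPs:
        def __init__(self):
            self.primary_op = _StubOp()
            self.lut_tensor = None

    class _StubDma:
        def __init__(self, out_tensor):
            self.cmdtype = CommandType.DMA
            self.out_tensor = out_tensor
            self.ps = _StubPs()

    class _StubSg:
        pass

    eq_id = create_equivalence_id("tanh-256")
    dma_a = _StubDma(_StubLut(address=None, size=256, eq_id=eq_id))
    dma_b = _StubDma(_StubLut(address=None, size=256, eq_id=eq_id))
    sg = _StubSg()
    sg.high_level_command_stream = [dma_a, dma_b]
    optimize_high_level_cmd_stream(sg, _StubArch())
    assert sg.high_level_command_stream == [dma_a]  # the second DMA was elided
    assert dma_a.ps.primary_op.attrs["lut_index"] == dma_b.ps.primary_op.attrs["lut_index"] == 0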