# SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Contains implementation of UnidirectionalSequenceLstm graph optimisation.
from enum import Enum
from typing import Tuple

import numpy as np

from .data_type import DataType
from .debug_database import DebugDatabase
from .graph_optimiser_util import create_avg_pool_for_concat
from .operation import ActivationFunction
from .operation import ExplicitScaling
from .operation import Op
from .operation import Operation
from .operation_util import create_add
from .operation_util import create_fullyconnected
from .operation_util import create_fused_activation
from .operation_util import create_mul
from .scaling import elementwise_mul_scale
from .shape4d import Shape4D
from .tensor import QuantizationParameters
from .tensor import Tensor

Q0_15_SCALE = np.float32(2**-15)
"""Q0.15 scale like the reference defines it"""
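# Q0.15 is a signed 16-bit fixed-point format with 15 fractional bits: an int16 value v
# represents v * 2**-15, so the representable range is [-1.0, 1.0 - 2**-15]
# (for example, 32767 * 2**-15 ~ 0.99997).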


class Lstm:
    """Lstm graph optimisation.

    Unrolls a UNIDIRECTIONAL_SEQUENCE_LSTM operation into its basic operations.

    Usage:

        unrolled_op = Lstm(op).get_graph()
    """

    class State(Enum):
        """States (variable tensors)"""

        OUTPUT = 18  # Value = tensor index
        CELL = 19  # Value = tensor index

    def __init__(self, op):
        self.op = op

    def get_graph(self) -> Operation:
        """Return the generated graph implementation"""
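        # The IFM is laid out [n_time, n_batch, n_feature] when time_major, otherwise
        # [n_batch, n_time, n_feature]; the loops below unroll one LSTM step per time
        # step (and, when not time_major, per batch).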
        self.op.ofm.ops = []
        if self.time_major:
            output_state = self.get_initial_state(Lstm.State.OUTPUT)
            cell_state = self.get_initial_state(Lstm.State.CELL)
            for time in range(self.n_time):
                feature = self.get_feature(time)
                output_state, cell_state = self.lstm_step(feature, output_state, cell_state, time)
                op = self.put_ofm(output_state, time)
        else:
            for batch in range(self.n_batch):
                output_state = self.get_initial_state(Lstm.State.OUTPUT, batch)
                cell_state = self.get_initial_state(Lstm.State.CELL, batch)
                for time in range(self.n_time):
                    feature = self.get_feature(time, batch)
                    output_state, cell_state = self.lstm_step(feature, output_state, cell_state, time, batch)
                    op = self.put_ofm(output_state, time, batch)
        return op

    def get_feature(self, time: int, batch: int = 0) -> Tensor:
        """Get input feature for provided time and batch"""
        feature = self.op.ifm.clone(f"_feature#{batch}.{time}")
        feature.set_all_shapes([self.n_batch if self.time_major else 1, self.n_feature])
        op = Operation(Op.SplitSliceRead, feature.name)
        op.add_input_tensor(self.op.ifm)
        op.set_output_tensor(feature)
        op.set_ifm_ofm_shapes()
        offset = [time, 0, 0] if self.time_major else [batch, time, 0]
        op.read_offsets[0] = Shape4D.from_list(offset, 0)
        op.read_shapes[0] = op.ofm_shapes[0]
        DebugDatabase.add_optimised(self.op, op)
        return feature

    def get_initial_state(self, state_type: State, batch: int = 0) -> Tensor:
        """Get state tensor for provided state type and batch"""
        state = self.state(state_type)
        if self.time_major:
            # For time major just return the 2D state, since all batches
            # are calculated at the same time
            return state
        else:
            # For non time major return one batch of the 2D state
            # by setting the read offset to the provided batch

            # The cloned state tensor will share equivalence id and buffer
            # with the variable state tensor
            n_state = state.shape[-1]
            state_ofm = state.clone(f"_state#{batch}")
            # Set shape to be one batch
            state_ofm.set_all_shapes([1, n_state])
            # Create the op for reading one batch of the state
            # (will be optimised away at a later stage)
            op = Operation(Op.SplitSliceRead, state_ofm.name)
            op.add_input_tensor(state)
            op.set_output_tensor(state_ofm)
            op.set_ifm_ofm_shapes()
            # Set the read offset to the provided batch
            op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
            # Set the read shape to one batch, see above
            op.read_shapes[0] = op.ofm_shapes[0]
            DebugDatabase.add_optimised(self.op, op)
            return state_ofm

    def get_state(self, op: Operation, batch: int = 0) -> Operation:
        """Set up the correct read offset for reading the state from
        a variable state tensor"""
        if not self.time_major and self.n_batch > 1:
            op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
            op.read_shapes[0] = Shape4D(op.ifm.shape)
            op.ifm_shapes[0] = Shape4D([self.n_batch, op.ifm.shape[-1]])
        return op

    def put_state(self, op: Operation, state_type: State, batch: int = 0) -> Operation:
        """Save the state for the provided batch by pointing the operation's
        ofm to the variable state tensor"""
        # The create op functions always return 4D shape, however the state
        # should have 2D shape for correct operation
        op.ofm.shape = op.ofm.shape[-2:]
        # Get state from type
        state = self.state(state_type)
        # By using the same equivalence_id the backing buffer for the ofm
        # tensor will be the state variable tensor buffer
        op.ofm.equivalence_id = state.equivalence_id
        # Set memory function which will make the tensor be in linear format
        # just as the state variable tensor
        op.memory_function = Op.VariableTensorWrite
        # Set the batch write offset into the state tensor buffer unless
        # time_major mode when all batches are written at once
        if not self.time_major:
            op.write_offset = Shape4D.from_list([batch, 0], 0)
            op.write_shape = Shape4D(op.ofm.shape)
            op.ofm_shapes = [Shape4D(state.shape)]
        DebugDatabase.add_optimised(self.op, op)
        return op

    def put_ofm(self, state: Tensor, time: int, batch: int = 0) -> Operation:
        """Save the output state for the provided batch and time to OFM"""
        name = f"{self.op.ofm.name}#{batch}.{time}"
        offset = Shape4D.from_list([time, 0, 0] if self.time_major else [batch, time, 0], 0)
        op = create_avg_pool_for_concat(self.op, name, state, Shape4D(state.shape), offset)
        # The provided state tensor uses the output state tensor's buffer, so unless
        # in time_major mode we need to set the correct batch read offset
        if not self.time_major:
            op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
            op.read_shapes[0] = Shape4D(state.shape)
            op.ifm_shapes[0] = Shape4D(self.output_state.shape)
        return op

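    # Each unrolled step implements the standard LSTM cell equations (no peephole,
    # projection or layer normalisation):
    #   i = sigmoid(Wi*x + Ri*h)    f = sigmoid(Wf*x + Rf*h)
    #   g = tanh(Wg*x + Rg*h)       o = sigmoid(Wo*x + Ro*h)
    #   c' = f*c + i*g              h' = o * tanh(c')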
    def lstm_step(
        self, feature: Tensor, output_state: Tensor, cell_state: Tensor, time: int, batch: int = 0
    ) -> Tuple[Tensor, Tensor]:
        """Generate one step of the LSTM implementation for the provided feature, batch and time"""
        input_gate = self.calculate_gate(
            f"input_gate#{batch}.{time}",
            feature,
            output_state,
            self.input_to_input_weights,
            self.input_bias,
            self.recurrent_to_input_weights,
            None,
            Op.Sigmoid,
            batch,
        )
        forget_gate = self.calculate_gate(
            f"forget_gate#{batch}.{time}",
            feature,
            output_state,
            self.input_to_forget_weights,
            self.forget_bias,
            self.recurrent_to_forget_weights,
            None,
            Op.Sigmoid,
            batch,
        )
        cell_gate = self.calculate_gate(
            f"cell_gate#{batch}.{time}",
            feature,
            output_state,
            self.input_to_cell_weights,
            self.cell_bias,
            self.recurrent_to_cell_weights,
            None,
            Op.Tanh,
            batch,
        )
        cell_state = self.calculate_cell_state(cell_state, input_gate, forget_gate, cell_gate, time, batch)
        output_gate = self.calculate_gate(
            f"output_gate#{batch}.{time}",
            feature,
            output_state,
            self.input_to_output_weights,
            self.output_bias,
            self.recurrent_to_output_weights,
            None,
            Op.Sigmoid,
            batch,
        )
        output_state = self.calculate_output_state(output_gate, cell_state, time, batch)
        return (output_state, cell_state)

    def calculate_gate(
        self,
        name: str,
        input: Tensor,
        state: Tensor,
        input_weights: Tensor,
        input_bias: Tensor,
        recurrent_weights: Tensor,
        recurrent_bias: Tensor,
        activation: Op,
        batch: int = 0,
    ):
        """Generate a gate for the provided input and weights"""
        # Activation( Add( FC(input), FC(output state) ) )
        # Setup fullyconnected quantization
        q_fc = QuantizationParameters()
        q_fc.scale_f32 = np.float32(2**-12)
        q_fc.zero_point = 0
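        # (The 2**-12 scale corresponds to the Q3.12 pre-activation format that the TFLite
        # reference feeds into its LUT based Sigmoid/Tanh.)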
        # Create fullyconnected
        in_fc = create_fullyconnected(f"{name}:{input.name}_fc", input, input_weights, input_bias, q_fc, False)
        re_fc = create_fullyconnected(f"{name}:{state.name}_fc", state, recurrent_weights, recurrent_bias, q_fc, False)
        self.get_state(re_fc, batch)
        # Change fullyconnected ofm data type
        in_fc.ofm.dtype = DataType.int16
        re_fc.ofm.dtype = DataType.int16
        # Setup add quantization
        q_add = q_fc.clone()
        q_add.scale_f32 = Q0_15_SCALE
        # Create add + activation
        add = create_add(f"{name}_add", in_fc.ofm, re_fc.ofm, q_add, ActivationFunction(activation))
        if activation is Op.Sigmoid:
            # For Sigmoid we need to set the activation min/max values to match the possible range
            # in the reference. The values below are the quantized min/max values that the reference
            # can achieve for the LUT based Sigmoid/Logistic. (The NPU does however have a larger range
            # due to intermediate higher precision.)
            # The quantized min/max values are divided by the effective output scale 0x3000 (3<<12) used for
            # elementwise operations with fused Tanh/Sigmoid activations (to get correct scaling before the
            # fused activation function). This will yield the dequantized min/max values which are later
            # quantized again by the command stream generator.
            add.activation.max = 32757 / 0x3000
            add.activation.min = 11 / 0x3000
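            # (0x3000 = 12288, so the dequantized limits are roughly 0.0009 and 2.666.)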
        # Add to debug database
        DebugDatabase.add_optimised(self.op, in_fc)
        DebugDatabase.add_optimised(self.op, re_fc)
        DebugDatabase.add_optimised(self.op, add)
        return add.ofm

    def calculate_cell_state(
        self, cell_state: Tensor, input_gate: Tensor, forget_gate: Tensor, cell_gate: Tensor, time: int, batch: int = 0
    ):
        """Update the cell state from the provided gate output"""
        # Clip( Add( Mul(cell state, forget gate), Mul(input gate, cell gate) ) )
        base_name = f"cell_state#{batch}.{time}"
        # Cell scale
        cell_scale = cell_state.quantization.scale_f32
        # Create mul(cell_state, forget_gate)
        mul_cf = create_mul(f"{base_name}_cf_mul", cell_state, forget_gate, cell_state.quantization)
        self.get_state(mul_cf, batch)
        # Calculate explicit scales to match reference
        multiplier, shift = elementwise_mul_scale(np.double(cell_scale), np.double(Q0_15_SCALE), np.double(cell_scale))
        mul_cf.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
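        # (The forget gate is Q0.15 while the cell state uses cell_scale, so the product is
        # rescaled from cell_scale * 2**-15 back to cell_scale.)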
        # Create mul(cell_gate, input_gate)
        mul_ci = create_mul(f"{base_name}_ci_mul", cell_gate, input_gate, cell_state.quantization)
        # Calculate explicit scales to match reference
        multiplier, shift = elementwise_mul_scale(np.double(Q0_15_SCALE), np.double(Q0_15_SCALE), np.double(cell_scale))
        mul_ci.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
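        # (Both gates are Q0.15, so the product is rescaled from 2**-30 to cell_scale.)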
        # Setup cell clip
        activation = None if self.cell_clip == 0 else ActivationFunction(Op.Clip)
        if activation:
            activation.max = self.cell_clip
            activation.min = -self.cell_clip
        # Create add + activation
        add = create_add(f"{base_name}_add", mul_cf.ofm, mul_ci.ofm, cell_state.quantization, activation)
        add.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
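        # (Identity scaling: both Mul results are already in cell_scale, so the Add needs no rescaling.)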
        # Save new state
        self.put_state(add, Lstm.State.CELL, batch)
        # Add to debug database
        DebugDatabase.add_optimised(self.op, mul_cf)
        DebugDatabase.add_optimised(self.op, mul_ci)
        DebugDatabase.add_optimised(self.op, add)
        return add.ofm

    def calculate_output_state(self, output_gate: Tensor, cell_state: Tensor, time: int, batch: int):
        """Generate the output state from the provided gate output"""
        # Mul( Tanh(cell state), output gate )
        base_name = f"output_state#{batch}.{time}"
        # Setup tanh quantization
        q_out_tanh = QuantizationParameters()
        q_out_tanh.scale_f32 = Q0_15_SCALE
        q_out_tanh.zero_point = 0
        # Create tanh(cell state)
        tanh = create_fused_activation(Op.Tanh, f"{base_name}_tanh", cell_state, q_out_tanh)
        self.get_state(tanh, batch)
        # Create Mul( Tanh(cell state), output gate )
        q_mul = self.output_state.quantization
        mul = create_mul(f"{base_name}_mul", tanh.ofm, output_gate, q_mul, dtype=self.op.ifm.dtype)
        # Use explicit scaling to match the reference; the following line would have been the preferred way:
        # mul.forced_output_quantization = self.hidden_quantization
        out_scale = self.hidden_quantization.scale_f32
        multiplier, shift = elementwise_mul_scale(np.double(Q0_15_SCALE), np.double(Q0_15_SCALE), np.double(out_scale))
        mul.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
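        # (Tanh output and output gate are both Q0.15, so the product is rescaled from 2**-30
        # to the hidden/output state scale.)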
        # Save new state
        self.put_state(mul, Lstm.State.OUTPUT, batch)
        # Add to debug database
        DebugDatabase.add_optimised(self.op, tanh)
        DebugDatabase.add_optimised(self.op, mul)
        return mul.ofm

    def state(self, state_type: State) -> Tensor:
        """Get state tensor from type"""
        return self.output_state if state_type == Lstm.State.OUTPUT else self.cell_state

    # Dimensions
    @property
    def n_feature(self) -> int:
        return self.op.ifm.shape[-1]

    @property
    def n_time(self) -> int:
        return self.op.ifm.shape[0 if self.time_major else 1]

    @property
    def n_batch(self) -> int:
        return self.op.ifm.shape[1 if self.time_major else 0]

    # Attributes
    @property
    def cell_clip(self) -> int:
        return self.op.attrs.get("cell_clip", 0)

    @property
    def projection_clip(self) -> int:
        return self.op.attrs.get("proj_clip", 0)

    @property
    def time_major(self) -> bool:
        return self.op.attrs.get("time_major", False)

    # Hidden (intermediate)
    @property
    def hidden_quantization(self) -> QuantizationParameters:
        return self.op.intermediates[4].quantization

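    # The tensor indices below follow the input ordering of the TFLite
    # UNIDIRECTIONAL_SEQUENCE_LSTM operator.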
    # Input weights
    @property
    def input_to_input_weights(self) -> Tensor:
        return self.op.inputs[1]

    @property
    def input_to_forget_weights(self) -> Tensor:
        return self.op.inputs[2]

    @property
    def input_to_cell_weights(self) -> Tensor:
        return self.op.inputs[3]

    @property
    def input_to_output_weights(self) -> Tensor:
        return self.op.inputs[4]

    # Recurrent weights
    @property
    def recurrent_to_input_weights(self) -> Tensor:
        return self.op.inputs[5]

    @property
    def recurrent_to_forget_weights(self) -> Tensor:
        return self.op.inputs[6]

    @property
    def recurrent_to_cell_weights(self) -> Tensor:
        return self.op.inputs[7]

    @property
    def recurrent_to_output_weights(self) -> Tensor:
        return self.op.inputs[8]

    # Peephole weights
    @property
    def cell_to_input_weights(self) -> Tensor:
        return self.op.inputs[9]

    @property
    def cell_to_forget_weights(self) -> Tensor:
        return self.op.inputs[10]

    @property
    def cell_to_output_weights(self) -> Tensor:
        return self.op.inputs[11]

    # Bias tensors
    @property
    def input_bias(self) -> Tensor:
        return self.op.inputs[12]

    @property
    def forget_bias(self) -> Tensor:
        return self.op.inputs[13]

    @property
    def cell_bias(self) -> Tensor:
        return self.op.inputs[14]

    @property
    def output_bias(self) -> Tensor:
        return self.op.inputs[15]

    # Projection tensors
    @property
    def projection_weights(self) -> Tensor:
        return self.op.inputs[16]

    @property
    def projection_bias(self) -> Tensor:
        return self.op.inputs[17]

    # State tensors (variable)
    @property
    def output_state(self) -> Tensor:
        return self.op.inputs[Lstm.State.OUTPUT.value]

    @property
    def cell_state(self) -> Tensor:
        return self.op.inputs[Lstm.State.CELL.value]