# SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Contains implementation of UnidirectionalSequenceLstm graph optimisation.
from enum import Enum
from typing import Tuple

import numpy as np

from .data_type import DataType
from .debug_database import DebugDatabase
from .graph_optimiser_util import create_avg_pool_for_concat
from .operation import ActivationFunction
from .operation import ExplicitScaling
from .operation import Op
from .operation import Operation
from .operation_util import create_add
from .operation_util import create_fullyconnected
from .operation_util import create_fused_activation
from .operation_util import create_mul
from .scaling import elementwise_mul_scale
from .shape4d import Shape4D
from .tensor import QuantizationParameters
from .tensor import Tensor

Q0_15_SCALE = np.float32(2**-15)
"""Q0.15 scale like the reference defines it"""
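# Q0.15 is a signed 16-bit fixed-point format with 15 fractional bits:
# one quantization step is 2**-15 and representable values lie in [-1, 1).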


class Lstm:
    """Lstm graph optimisation.

    Unrolls a UNIDIRECTIONAL_SEQUENCE_LSTM operation into its basic operations.

    Usage:

        unrolled_op = Lstm(op).get_graph()
    """

    class State(Enum):
        """States (variable tensors)"""

        OUTPUT = 18  # Value = tensor index
        CELL = 19  # Value = tensor index
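        # 18 and 19 are the indices of the output-state and cell-state variable
        # tensors in the LSTM op's input list; see the output_state/cell_state
        # properties at the bottom of this class.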

    def __init__(self, op):
        self.op = op

    def get_graph(self) -> Operation:
        """Return the generated graph implementation"""
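        # In time_major mode all batches are processed together for each time
        # step, so the 2D state tensors are read and written whole. Otherwise
        # each batch is unrolled separately and only its slice of the state
        # tensors is read and written.
        # Clear the OFM's producer list so only the unrolled ops below write it.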
        self.op.ofm.ops = []
        if self.time_major:
            output_state = self.get_initial_state(Lstm.State.OUTPUT)
            cell_state = self.get_initial_state(Lstm.State.CELL)
            for time in range(self.n_time):
                feature = self.get_feature(time)
                output_state, cell_state = self.lstm_step(feature, output_state, cell_state, time)
                op = self.put_ofm(output_state, time)
        else:
            for batch in range(self.n_batch):
                output_state = self.get_initial_state(Lstm.State.OUTPUT, batch)
                cell_state = self.get_initial_state(Lstm.State.CELL, batch)
                for time in range(self.n_time):
                    feature = self.get_feature(time, batch)
                    output_state, cell_state = self.lstm_step(feature, output_state, cell_state, time, batch)
                    op = self.put_ofm(output_state, time, batch)
        return op

    def get_feature(self, time: int, batch: int = 0) -> Tensor:
        """Get input feature for provided time and batch"""
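        # Slice one time step (and, for non time major, one batch) out of the
        # IFM using a SplitSliceRead which, like the one in get_initial_state,
        # will be optimised away at a later stage.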
        feature = self.op.ifm.clone(f"_feature#{batch}.{time}")
        feature.set_all_shapes([self.n_batch if self.time_major else 1, self.n_feature])
        op = Operation(Op.SplitSliceRead, feature.name)
        op.add_input_tensor(self.op.ifm)
        op.set_output_tensor(feature)
        op.set_ifm_ofm_shapes()
        offset = [time, 0, 0] if self.time_major else [batch, time, 0]
        op.read_offsets[0] = Shape4D.from_list(offset, 0)
        op.read_shapes[0] = op.ofm_shapes[0]
        DebugDatabase.add_optimised(self.op, op)
        return feature

    def get_initial_state(self, state_type: State, batch: int = 0) -> Tensor:
        """Get state tensor for provided state type and batch"""
        state = self.state(state_type)
        if self.time_major:
            # For time major just return the 2D state, since all batches
            # are calculated at the same time
            return state
        else:
            # For non time major return one batch of the 2D state
            # by setting the read offset to the provided batch

            # The cloned state tensor will share equivalence id and buffer
            # with the variable state tensor
            n_state = state.shape[-1]
            state_ofm = state.clone(f"_state#{batch}")
            # Set shape to be one batch
            state_ofm.set_all_shapes([1, n_state])
            # Create the op for reading one batch of the state
            # (will be optimised away at a later stage)
            op = Operation(Op.SplitSliceRead, state_ofm.name)
            op.add_input_tensor(state)
            op.set_output_tensor(state_ofm)
            op.set_ifm_ofm_shapes()
            # Set the read offset to the provided batch
            op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
            # Set the read shape to one batch, see above
            op.read_shapes[0] = op.ofm_shapes[0]
            DebugDatabase.add_optimised(self.op, op)
            return state_ofm

    def get_state(self, op: Operation, batch: int = 0) -> Operation:
        """Set up the correct read offset for reading the state from
        a variable state tensor"""
        if not self.time_major and self.n_batch > 1:
            op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
            op.read_shapes[0] = Shape4D(op.ifm.shape)
            op.ifm_shapes[0] = Shape4D([self.n_batch, op.ifm.shape[-1]])
        return op

    def put_state(self, op: Operation, state_type: State, batch: int = 0) -> Operation:
        """Save the state for the provided batch by pointing the operation's
        ofm to the variable state tensor"""
        # The create op functions always return 4D shape, however the state
        # should have 2D shape for correct operation
        op.ofm.shape = op.ofm.shape[-2:]
        # Get state from type
        state = self.state(state_type)
        # By using the same equivalence_id the backing buffer for the ofm
        # tensor will be the state variable tensor buffer
        op.ofm.equivalence_id = state.equivalence_id
        # Set the memory function which will make the tensor be in linear format,
        # just like the state variable tensor
        op.memory_function = Op.VariableTensorWrite
        # Set the batch write offset into the state tensor buffer, unless in
        # time_major mode when all batches are written at once
        if not self.time_major:
            op.write_offset = Shape4D.from_list([batch, 0], 0)
            op.write_shape = Shape4D(op.ofm.shape)
            op.ofm_shapes = [Shape4D(state.shape)]
        DebugDatabase.add_optimised(self.op, op)
        return op

    def put_ofm(self, state: Tensor, time: int, batch: int = 0) -> Operation:
        """Save the output state for the provided batch and time to OFM"""
        name = f"{self.op.ofm.name}#{batch}.{time}"
        offset = Shape4D.from_list([time, 0, 0] if self.time_major else [batch, time, 0], 0)
        op = create_avg_pool_for_concat(self.op, name, state, Shape4D(state.shape), offset)
        # The provided state tensor uses the output state tensor's buffer, so
        # unless in time_major mode we need to set the correct batch read offset
        if not self.time_major:
            op.read_offsets[0] = Shape4D.from_list([batch, 0], 0)
            op.read_shapes[0] = Shape4D(state.shape)
            op.ifm_shapes[0] = Shape4D(self.output_state.shape)
        return op

    def lstm_step(
        self, feature: Tensor, output_state: Tensor, cell_state: Tensor, time: int, batch: int = 0
    ) -> Tuple[Tensor, Tensor]:
        """Generate one step of the LSTM implementation for the provided feature, batch and time"""
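        # As implemented here (no peephole, projection or layer normalisation):
        #   i = sigmoid(FC_i(x) + FC_ri(h))    input gate
        #   f = sigmoid(FC_f(x) + FC_rf(h))    forget gate
        #   g = tanh(FC_c(x) + FC_rc(h))       cell gate
        #   c = clip(f * c_prev + i * g)       new cell state
        #   o = sigmoid(FC_o(x) + FC_ro(h))    output gate
        #   h = o * tanh(c)                    new output state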
        input_gate = self.calculate_gate(
            f"input_gate#{batch}.{time}",
            feature,
            output_state,
            self.input_to_input_weights,
            self.input_bias,
            self.recurrent_to_input_weights,
            None,
            Op.Sigmoid,
            batch,
        )
        forget_gate = self.calculate_gate(
            f"forget_gate#{batch}.{time}",
            feature,
            output_state,
            self.input_to_forget_weights,
            self.forget_bias,
            self.recurrent_to_forget_weights,
            None,
            Op.Sigmoid,
            batch,
        )
        cell_gate = self.calculate_gate(
            f"cell_gate#{batch}.{time}",
            feature,
            output_state,
            self.input_to_cell_weights,
            self.cell_bias,
            self.recurrent_to_cell_weights,
            None,
            Op.Tanh,
            batch,
        )
        cell_state = self.calculate_cell_state(cell_state, input_gate, forget_gate, cell_gate, time, batch)
        output_gate = self.calculate_gate(
            f"output_gate#{batch}.{time}",
            feature,
            output_state,
            self.input_to_output_weights,
            self.output_bias,
            self.recurrent_to_output_weights,
            None,
            Op.Sigmoid,
            batch,
        )
        output_state = self.calculate_output_state(output_gate, cell_state, time, batch)
        return (output_state, cell_state)

    def calculate_gate(
        self,
        name: str,
        input: Tensor,
        state: Tensor,
        input_weights: Tensor,
        input_bias: Tensor,
        recurrent_weights: Tensor,
        recurrent_bias: Tensor,
        activation: Op,
        batch: int = 0,
    ):
        """Generate a gate for the provided input and weights"""
        # Activation( Add( FC(input), FC(output state) ) )
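        # The two fully connected ops produce int16 pre-activations with scale
        # 2**-12; the Add combines them and, together with the fused
        # Sigmoid/Tanh, yields the gate output in Q0.15, mirroring the
        # precision used by the reference.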
        # Setup fullyconnected quantization
        q_fc = QuantizationParameters()
        q_fc.scale_f32 = np.float32(2**-12)
        q_fc.zero_point = 0
        # Create fullyconnected
        in_fc = create_fullyconnected(f"{name}:{input.name}_fc", input, input_weights, input_bias, q_fc, False)
        re_fc = create_fullyconnected(f"{name}:{state.name}_fc", state, recurrent_weights, recurrent_bias, q_fc, False)
        self.get_state(re_fc, batch)
        # Change fullyconnected ofm data type
        in_fc.ofm.dtype = DataType.int16
        re_fc.ofm.dtype = DataType.int16
        # Setup add quantization
        q_add = q_fc.clone()
        q_add.scale_f32 = Q0_15_SCALE
        # Create add + activation
        add = create_add(f"{name}_add", in_fc.ofm, re_fc.ofm, q_add, ActivationFunction(activation))
        if activation is Op.Sigmoid:
            # For Sigmoid we need to set the activation min/max values to match the possible range
            # in the reference. The values below are the quantized min/max values that the reference
            # can achieve for the LUT based Sigmoid/Logistic. (The NPU does however have a larger range
            # due to intermediate higher precision.)
            # The quantized min/max values are divided by the effective output scale 0x3000 (3<<12) used for
            # elementwise operations with fused Tanh/Sigmoid activations (to get correct scaling before the
            # fused activation function). This will yield the dequantized min/max values which are later
            # quantized again by the command stream generator.
            add.activation.max = 32757 / 0x3000
            add.activation.min = 11 / 0x3000
        # Add to debug database
        DebugDatabase.add_optimised(self.op, in_fc)
        DebugDatabase.add_optimised(self.op, re_fc)
        DebugDatabase.add_optimised(self.op, add)
        return add.ofm

    def calculate_cell_state(
        self, cell_state: Tensor, input_gate: Tensor, forget_gate: Tensor, cell_gate: Tensor, time: int, batch: int = 0
    ):
        """Update the cell state from the provided gate output"""
        # Clip( Add( Mul(cell state, forget gate), Mul(input gate, cell gate) ) )
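        # i.e. c_new = clip(c_prev * f + i * g, +/-cell_clip), with explicit
        # scaling so the intermediate products stay in the cell state's
        # quantization, matching the reference.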
        base_name = f"cell_state#{batch}.{time}"
        # Cell scale
        cell_scale = cell_state.quantization.scale_f32
        # Create mul(cell_state, forget_gate)
        mul_cf = create_mul(f"{base_name}_cf_mul", cell_state, forget_gate, cell_state.quantization)
        self.get_state(mul_cf, batch)
        # Calculate explicit scales to match reference
        multiplier, shift = elementwise_mul_scale(np.double(cell_scale), np.double(Q0_15_SCALE), np.double(cell_scale))
        mul_cf.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
        # Create mul(cell_gate, input_gate)
        mul_ci = create_mul(f"{base_name}_ci_mul", cell_gate, input_gate, cell_state.quantization)
        # Calculate explicit scales to match reference
        multiplier, shift = elementwise_mul_scale(np.double(Q0_15_SCALE), np.double(Q0_15_SCALE), np.double(cell_scale))
        mul_ci.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
        # Setup cell clip
        activation = None if self.cell_clip == 0 else ActivationFunction(Op.Clip)
        if activation:
            activation.max = self.cell_clip
            activation.min = -self.cell_clip
        # Create add + activation
        add = create_add(f"{base_name}_add", mul_cf.ofm, mul_ci.ofm, cell_state.quantization, activation)
        add.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
        # Save new state
        self.put_state(add, Lstm.State.CELL, batch)
        # Add to debug database
        DebugDatabase.add_optimised(self.op, mul_cf)
        DebugDatabase.add_optimised(self.op, mul_ci)
        DebugDatabase.add_optimised(self.op, add)
        return add.ofm

    def calculate_output_state(self, output_gate: Tensor, cell_state: Tensor, time: int, batch: int):
        """Generate the output state from the provided gate output"""
        # Mul( Tanh(cell state), output gate )
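        # i.e. h_new = tanh(c) * o, where the tanh output is in Q0.15 and the
        # multiply is rescaled into the hidden/output quantization.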
        base_name = f"output_state#{batch}.{time}"
        # Setup tanh quantization
        q_out_tanh = QuantizationParameters()
        q_out_tanh.scale_f32 = Q0_15_SCALE
        q_out_tanh.zero_point = 0
        # Create tanh(cell state)
        tanh = create_fused_activation(Op.Tanh, f"{base_name}_tanh", cell_state, q_out_tanh)
        self.get_state(tanh, batch)
        # Create Mul( Tanh(cell state), output gate )
        q_mul = self.output_state.quantization
        mul = create_mul(f"{base_name}_mul", tanh.ofm, output_gate, q_mul, dtype=self.op.ifm.dtype)
        # Use explicit scaling to match the reference; the following line would have been the preferred way:
        # mul.forced_output_quantization = self.hidden_quantization
        out_scale = self.hidden_quantization.scale_f32
        multiplier, shift = elementwise_mul_scale(np.double(Q0_15_SCALE), np.double(Q0_15_SCALE), np.double(out_scale))
        mul.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
        # Save new state
        self.put_state(mul, Lstm.State.OUTPUT, batch)
        # Add to debug database
        DebugDatabase.add_optimised(self.op, tanh)
        DebugDatabase.add_optimised(self.op, mul)
        return mul.ofm

    def state(self, state_type: State) -> Tensor:
        """Get state tensor from type"""
        return self.output_state if state_type == Lstm.State.OUTPUT else self.cell_state

    # Dimensions
    @property
    def n_feature(self) -> int:
        return self.op.ifm.shape[-1]

    @property
    def n_time(self) -> int:
        return self.op.ifm.shape[0 if self.time_major else 1]

    @property
    def n_batch(self) -> int:
        return self.op.ifm.shape[1 if self.time_major else 0]

    # Attributes
    @property
    def cell_clip(self) -> int:
        return self.op.attrs.get("cell_clip", 0)

    @property
    def projection_clip(self) -> int:
        return self.op.attrs.get("proj_clip", 0)

    @property
    def time_major(self) -> bool:
        return self.op.attrs.get("time_major", False)

    # Hidden (intermediate)
    @property
    def hidden_quantization(self) -> QuantizationParameters:
        return self.op.intermediates[4].quantization

    # Input weights
    @property
    def input_to_input_weights(self) -> Tensor:
        return self.op.inputs[1]

    @property
    def input_to_forget_weights(self) -> Tensor:
        return self.op.inputs[2]

    @property
    def input_to_cell_weights(self) -> Tensor:
        return self.op.inputs[3]

    @property
    def input_to_output_weights(self) -> Tensor:
        return self.op.inputs[4]

    # Recurrent weights
    @property
    def recurrent_to_input_weights(self) -> Tensor:
        return self.op.inputs[5]

    @property
    def recurrent_to_forget_weights(self) -> Tensor:
        return self.op.inputs[6]

    @property
    def recurrent_to_cell_weights(self) -> Tensor:
        return self.op.inputs[7]

    @property
    def recurrent_to_output_weights(self) -> Tensor:
        return self.op.inputs[8]

    # Peephole weights
    @property
    def cell_to_input_weights(self) -> Tensor:
        return self.op.inputs[9]

    @property
    def cell_to_forget_weights(self) -> Tensor:
        return self.op.inputs[10]

    @property
    def cell_to_output_weights(self) -> Tensor:
        return self.op.inputs[11]

    # Bias tensors
    @property
    def input_bias(self) -> Tensor:
        return self.op.inputs[12]

    @property
    def forget_bias(self) -> Tensor:
        return self.op.inputs[13]

    @property
    def cell_bias(self) -> Tensor:
        return self.op.inputs[14]

    @property
    def output_bias(self) -> Tensor:
        return self.op.inputs[15]

    # Projection tensors
    @property
    def projection_weights(self) -> Tensor:
        return self.op.inputs[16]

    @property
    def projection_bias(self) -> Tensor:
        return self.op.inputs[17]

    # State tensors (variable)
    @property
    def output_state(self) -> Tensor:
        return self.op.inputs[Lstm.State.OUTPUT.value]

    @property
    def cell_state(self) -> Tensor:
        return self.op.inputs[Lstm.State.CELL.value]