blob: 44ce711f8a6c8f49d7951c8a43882d2cdf5f1674 [file] [log] [blame]
Johan Alfven9070f0f2023-02-07 13:01:03 +01001# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
Tim Hall79d07d22020-04-27 18:20:16 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Rickard Bolinbc6ee582022-11-04 08:24:29 +000016#
Tim Hall79d07d22020-04-27 18:20:16 +010017# Description:
18# Functions used to write to a TensorFlow Lite format file. Supports adding in file identifiers.
Tim Hall79d07d22020-04-27 18:20:16 +010019import flatbuffers
Diego Russoe8a10452020-04-21 17:39:10 +010020import flatbuffers.number_types as N
21import numpy as np
22from flatbuffers import encode
Diego Russoea6111a2020-04-14 18:41:58 +010023from flatbuffers.builder import UOffsetTFlags
24
William Isakssonea8c5372023-07-03 20:31:42 +000025from ._version import __version__
Michael McGeagh7a6f8432020-12-02 15:29:22 +000026from .errors import VelaError
Diego Russoe8a10452020-04-21 17:39:10 +010027from .nn_graph import PassPlacement
Louis Verhaardaee5d752020-09-30 09:01:52 +020028from .operation import Op
Patrik Gustavsson5e26eda2021-06-30 09:07:16 +020029from .reader_util import align_inputs_indices
Patrik Gustavssoneca2e952020-05-27 09:15:11 +020030from .tensor import MemType
Johan Alfvénb9f81592022-10-31 14:39:02 +010031from .tensor import shape_num_elements
Samuel Panijel6f4955a2021-06-10 13:40:03 +030032from .tensor import TensorPurpose
Tim Hall79d07d22020-04-27 18:20:16 +010033from .tflite import Buffer
34from .tflite import Metadata
Diego Russoe8a10452020-04-21 17:39:10 +010035from .tflite import Model
36from .tflite import Operator
37from .tflite import OperatorCode
38from .tflite import QuantizationParameters
39from .tflite import SubGraph
40from .tflite import Tensor
41from .tflite_mapping import builtin_operator_inv_map
42from .tflite_mapping import BuiltinOperator
Diego Russoe8a10452020-04-21 17:39:10 +010043from .tflite_mapping import datatype_inv_map
Tim Hall2180a172023-03-10 18:11:34 +000044from .tflite_mapping import optype_to_builtintype
Diego Russoe8a10452020-04-21 17:39:10 +010045
Tim Hallffe8e282021-06-24 18:29:53 +010046# the python flatbuffer interface is missing a method to add in file identifier. patching it in here:
Tim Hall79d07d22020-04-27 18:20:16 +010047
# Version number written into the model's version field by TFLiteSerialiser.serialise_model()
tflite_version = 3
# Four character file identifier ("TFL" followed by the version number) that is passed to
# FinishWithFileIdentifier() when the flatbuffer is finished
tflite_file_identifier = "TFL" + str(tflite_version)
50
51
def FinishWithFileIdentifier(self, rootTable, fid):
    """Finish the flatbuffer rooted at ``rootTable`` after writing the file identifier ``fid``.

    This function is written as a method of ``flatbuffers.Builder`` (``self`` is the builder).

    :param rootTable: offset of the root table, passed on unchanged to ``Finish()``
    :param fid: the file identifier; must have a length of exactly 4
    :return: whatever ``Finish()`` returns
    :raises VelaError: if ``fid`` is None or is not 4 characters long
    """
    if fid is None or len(fid) != 4:
        raise VelaError("FileIdentifier must be 4 chars")

    byte_flags = N.Uint8Flags
    reserved_size = 4
    self.Prep(self.minalign, reserved_size + len(fid))
    # the head moves towards lower addresses, so the last character is written first
    for char in reversed(fid):
        self.head = self.head - byte_flags.bytewidth
        encode.Write(byte_flags.packer_type, self.Bytes, self.Head(), ord(char))

    return self.Finish(rootTable)
64
65
# Patch the function above into the flatbuffers Builder class so that it can be called as a
# regular method on any builder instance (see TFLiteSerialiser.serialise)
flatbuffers.Builder.FinishWithFileIdentifier = FinishWithFileIdentifier
67
68
def make_vector(v):
    """Return ``v`` unchanged if ``len()`` can be taken of it, otherwise wrap it in a one-element list."""
    try:
        len(v)
    except TypeError:
        # no length (e.g. a plain scalar), so turn it into a vector of one element
        return [v]
    return v
75
76
class TFLiteSerialiser:
    """Serialises the Cpu-placed subgraphs of a network graph into a TensorFlow Lite flatbuffer.

    Typical use is ``TFLiteSerialiser(nng).serialise()`` which returns the finished flatbuffer, or
    ``TFLiteSerialiser(nng).write(filename)`` which also writes it to a file.
    """

    # The 0th buffer is always by default an empty buffer that can be used by tensors
    # without any constant data
    BUF_IDX_ZERO = 0
    BUF_IDX_START = 1

    def __init__(self, nng):
        """Collect the subgraphs and operators that will be written.

        Note that the operators of ``nng`` are modified in place: their inputs are re-ordered to the TensorFlow Lite
        input indexing, and the weight/bias inputs of convolution and fully connected operators are replaced by their
        ``src_tensor`` when one is set.

        :param nng: the network graph to serialise
        """
        self.builder = flatbuffers.Builder(0)
        self.nng = nng

        self.buf_idx = TFLiteSerialiser.BUF_IDX_START  # next free buffer index
        self.buffers_to_write = []  # buffer contents, indexed by buffer index
        self.tensor_map_all = []  # Keep track of all subgraphs (one tensor -> index map per subgraph)
        self.tensor_map_sg = {}  # Keep track of one subgraph (tensor -> index map)

        self.ops_to_ignore = (Op.Const, Op.Placeholder, Op.SubgraphInput)

        self.subgraphs_to_write = [sg for sg in self.nng.subgraphs if sg.placement == PassPlacement.Cpu]

        all_ops = []
        for sg in self.subgraphs_to_write:
            for ps in sg.passes:
                for op in ps.ops:
                    if op.type in self.ops_to_ignore:
                        continue
                    # swap from nng input indexing to TensorFlow Lite input indexing
                    self.align_nng_inputs_to_tflite(op)
                    all_ops.append(op)
                    if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type == Op.FullyConnected:
                        # Op is run on CPU, make sure the original weight and bias tensors are written back
                        # instead of the cloned/reshaped (see tflite_reader)
                        for idx, inp in enumerate(op.inputs):
                            if inp != op.ifm and inp is not None and inp.src_tensor is not None:
                                op.inputs[idx] = inp.src_tensor

        # list of tuple(Op, string, op.version); the custom code is only used for 3rd party custom operators
        self.operator_codes = sorted(set((op.type, op.attrs.get("custom_code", ""), op.version) for op in all_ops))
        # filled in by serialise_operator_code() and looked up by serialise_operator()
        self.operator_code_map = {}

    def align_nng_inputs_to_tflite(self, op):
        """Re-order ``op.inputs`` in place from the nng input indexing to the TensorFlow Lite input indexing."""
        from_indices = op.type.info.indices
        _, _, to_indices = builtin_operator_inv_map[op.type]
        op.inputs = align_inputs_indices(from_indices, to_indices, op.inputs)

    def _write_vector(self, values, elem_size, alignment, prepend):
        """Write ``values`` as a flatbuffer vector and return the value of the builder's ``EndVector()``.

        :param values: sliceable sequence of elements to write
        :param elem_size: size in bytes of one element
        :param alignment: alignment passed to the builder's ``StartVector()``
        :param prepend: bound builder method that prepends a single element
        """
        builder = self.builder
        builder.StartVector(elem_size, len(values), alignment)
        # elements are prepended, so they are visited in reverse to keep the original order in the vector
        for elem in values[::-1]:
            prepend(elem)
        return builder.EndVector()

    def write_byte_vector(self, v, alignment=1):
        """Write ``v`` as a vector of bytes using the given ``alignment``."""
        return self._write_vector(v, 1, alignment, self.builder.PrependByte)

    def write_int_vector(self, v):
        """Write ``v`` as a vector of 32-bit integers."""
        return self._write_vector(v, 4, 4, self.builder.PrependInt32)

    def write_long_vector(self, v):
        """Write ``v`` as a vector of 64-bit integers."""
        return self._write_vector(v, 8, 8, self.builder.PrependInt64)

    def write_float_vector(self, v):
        """Write ``v`` as a vector of 32-bit floats."""
        return self._write_vector(v, 4, 4, self.builder.PrependFloat32)

    def write_offset_vector(self, v):
        """Write ``v``, a sequence of offsets of already written objects, as a vector of relative offsets."""
        return self._write_vector(v, 4, 4, self.builder.PrependUOffsetTRelative)

    def assign_buffers_to_tensors(self, tensors, scratch_tensor):
        """Assign a buffer index to every tensor in ``tensors``.

        :param tensors: the tensors of one subgraph
        :param scratch_tensor: the scratch tensor of the subgraph, or None if there is none
        :return: dict that maps each tensor to its buffer index
        """
        if scratch_tensor is not None:
            scratch_tensor_mem_area = scratch_tensor.mem_area
        else:
            scratch_tensor_mem_area = None  # all tensors are initialised to MemArea.Unknown

        buffer_map = {}

        for tens in tensors:
            # Set buffer ids depending on allocation
            if tens.is_allocated_in_tensor_arena(scratch_tensor_mem_area) or tens.mem_type == MemType.Scratch_fast:
                # Tensor allocated in the scratch areas, does not have any constant data and can
                # therefore all point to the empty buffer (zero)
                buffer_map[tens] = TFLiteSerialiser.BUF_IDX_ZERO
            else:
                buffer_map[tens] = self.buf_idx
                self.buf_idx += 1

        # Initialize/extend buffers_to_write to a length equal to number of buffers so
        # they can be appended at the correct index during tensor serialization.
        # Only the missing entries are added; this method is called once per subgraph and adding buf_idx entries
        # every time would leave unused empty buffers in the file
        missing = self.buf_idx - len(self.buffers_to_write)
        if missing > 0:
            self.buffers_to_write += [None] * missing

        return buffer_map

    def serialise_operator_code(self, idx, op_type, custom_code, version):
        """Write one OperatorCode table and record it in ``self.operator_code_map``.

        :param idx: index of this operator code in the model's operator code list
        :param op_type: the operator type
        :param custom_code: custom code string; only used for 3rd party custom operators (Op.Custom)
        :param version: operator version
        :return: offset of the written OperatorCode table
        """
        builder = self.builder
        custom_code_offset = None
        if op_type == Op.Custom:
            tf_code, opt_serializer, _ = builtin_operator_inv_map[op_type]
            custom_code_offset = builder.CreateString(custom_code)
        else:
            assert (
                op_type in builtin_operator_inv_map
            ), "Vela does not contain a mapping to serialise {} operator to a TensorFlow Lite operator".format(op_type)
            tf_code, opt_serializer, _ = builtin_operator_inv_map[op_type]

            if op_type == Op.CustomNpuOp:
                assert (
                    tf_code == BuiltinOperator.CUSTOM
                ), "Vela only supports serialising NpuOp operators as TensorFlow Lite Custom operators"
                custom_code_offset = builder.CreateString("ethos-u")

        # there can be multiple different types of 3rd party custom operators (i.e. non-"ethos-u" ones). therefore we
        # need to add an extra level of indirection to this particular entry in the operator_code_map to allow for the
        # correct lookup later on
        code_entry = (idx, tf_code, opt_serializer)
        if op_type == Op.Custom:
            self.operator_code_map.setdefault(op_type, {})[custom_code] = code_entry
        else:
            self.operator_code_map[op_type] = code_entry

        OperatorCode.OperatorCodeStart(builder)
        # the deprecated builtin code field is capped at 127
        OperatorCode.OperatorCodeAddDeprecatedBuiltinCode(builder, tf_code if tf_code < 127 else 127)
        OperatorCode.OperatorCodeAddBuiltinCode(builder, tf_code)
        OperatorCode.OperatorCodeAddVersion(builder, version)
        if custom_code_offset is not None:
            OperatorCode.OperatorCodeAddCustomCode(builder, custom_code_offset)

        return OperatorCode.OperatorCodeEnd(builder)

    def serialise_quantization_parameters(self, quant):
        """Write a QuantizationParameters table for ``quant``.

        Only the fields of ``quant`` that are not None are written.

        :param quant: the quantization parameters, or None
        :return: offset of the written table, or None if ``quant`` is None
        """
        if quant is None:
            return None

        builder = self.builder

        # the vectors must be written before the table that refers to them is started
        min_offset = None
        max_offset = None
        scale_offset = None
        zero_point_offset = None
        if quant.min is not None:
            min_offset = self.write_float_vector(make_vector(quant.min))
        if quant.max is not None:
            max_offset = self.write_float_vector(make_vector(quant.max))
        if quant.scale_f32 is not None:
            scale_offset = self.write_float_vector(make_vector(quant.scale_f32))
        if quant.zero_point is not None:
            zero_point_offset = self.write_long_vector(make_vector(quant.zero_point))

        QuantizationParameters.QuantizationParametersStart(builder)
        if min_offset is not None:
            QuantizationParameters.QuantizationParametersAddMin(builder, min_offset)
        if max_offset is not None:
            QuantizationParameters.QuantizationParametersAddMax(builder, max_offset)
        if scale_offset is not None:
            QuantizationParameters.QuantizationParametersAddScale(builder, scale_offset)
        if zero_point_offset is not None:
            QuantizationParameters.QuantizationParametersAddZeroPoint(builder, zero_point_offset)
        if quant.quant_dim is not None:
            QuantizationParameters.QuantizationParametersAddQuantizedDimension(builder, quant.quant_dim)
        return QuantizationParameters.QuantizationParametersEnd(builder)

    def serialise_tensor(self, tens):
        """Write a Tensor table for ``tens`` and store its values in ``self.buffers_to_write``.

        ``self.buffer_map`` must contain ``tens`` (see serialise_subgraph).

        :return: offset of the written Tensor table
        """
        builder = self.builder
        if shape_num_elements(tens.original_shape) != shape_num_elements(tens.shape):
            # shapes have changed size, therefore assume that the latest (modified) shape is correct
            tens_shape = tens.shape
        else:
            # shapes have not changed size, therefore the original shape is valid
            tens_shape = tens.original_shape
        values = tens.values

        buf_id = self.buffer_map[tens]
        # Sanity check that if buffer 0 is used there must not be any data
        assert not (buf_id == TFLiteSerialiser.BUF_IDX_ZERO and values is not None)
        # the values are stored as a flat array of bytes
        self.buffers_to_write[buf_id] = None if values is None else values.flatten().view(np.uint8)

        shape = self.write_int_vector(tens_shape)

        name = builder.CreateString(tens.name)
        quant = self.serialise_quantization_parameters(tens.quantization)

        Tensor.TensorStart(builder)
        Tensor.TensorAddShape(builder, shape)
        Tensor.TensorAddType(builder, datatype_inv_map[tens.dtype])
        # All tensors must have a valid backing buffer, even if it is empty.
        # Empty buffers should be kept unique for TensorFlow Lite Micro
        Tensor.TensorAddBuffer(builder, buf_id)
        Tensor.TensorAddName(builder, name)
        if quant is not None:
            Tensor.TensorAddQuantization(builder, quant)
        Tensor.TensorAddIsVariable(builder, tens.is_variable)

        return Tensor.TensorEnd(builder)

    def serialise_operator(self, op):
        """Write an Operator table for ``op``.

        Tensors are referred to by their index in ``self.tensor_map_sg``. An input that is not in the map (for
        example None) is written as -1, whereas outputs and intermediates that are not in the map are left out.

        :return: offset of the written Operator table
        """
        builder = self.builder
        tensor_map = self.tensor_map_sg

        inputs_offset = self.write_int_vector([tensor_map.get(tens, -1) for tens in op.inputs])
        outputs_offset = self.write_int_vector([tensor_map[tens] for tens in op.outputs if tens in tensor_map])
        intermediates_offset = self.write_int_vector(
            [tensor_map[tens] for tens in op.intermediates if tens in tensor_map]
        )

        if op.type == Op.Custom:
            op_idx, tflop, opt_serializer = self.operator_code_map[op.type][op.attrs.get("custom_code", "")]
        else:
            op_idx, tflop, opt_serializer = self.operator_code_map[op.type]

        builtin_opt_offset = None
        custom_opt_offset = None
        if opt_serializer is not None:
            # work on a copy so that the attributes of the operator itself are left untouched
            attrs = dict(op.attrs)
            if op.run_on_npu:
                if "strides" in attrs:
                    attrs["stride_h"] = attrs["strides"][1]
                    attrs["stride_w"] = attrs["strides"][2]
                if "ksize" in attrs:
                    attrs["filter_height"] = attrs["ksize"][1]
                    attrs["filter_width"] = attrs["ksize"][2]
                if "dilation" in attrs:
                    attrs["dilation_h_factor"] = attrs["dilation"][1]
                    attrs["dilation_w_factor"] = attrs["dilation"][2]
                if "channel_multiplier" in attrs:
                    attrs["depth_multiplier"] = attrs["channel_multiplier"]
                attrs["fused_activation_function"] = op.activation.op_type if op.activation is not None else None

            # Serialize VarHandleOptions (only op that have attributes with type String)
            if "container" in attrs:
                attrs["container"] = builder.CreateString(attrs["container"])
            if "shared_name" in attrs:
                attrs["shared_name"] = builder.CreateString(attrs["shared_name"])

            builtin_opt_offset, custom_opt_offset = opt_serializer.serialize(builder, attrs)

            # report any missing attributes that could not be written during serialize().
            # operators that have been created internally (i.e. not created as part of reading an input network) may not
            # have the write error attribute
            attribute_write_error = attrs.get("attribute_write_error", [])
            if attribute_write_error:
                print(
                    f"Warning: Could not write the following attributes to {optype_to_builtintype(op.type)}"
                    f" '{op.name}' {opt_serializer.name} field: {', '.join(attribute_write_error)}"
                )

        mutating_variable_inputs_offset = self.write_byte_vector([])
        Operator.OperatorStart(builder)
        Operator.OperatorAddOpcodeIndex(builder, op_idx)
        Operator.OperatorAddInputs(builder, inputs_offset)
        Operator.OperatorAddOutputs(builder, outputs_offset)
        Operator.OperatorAddIntermediates(builder, intermediates_offset)

        if builtin_opt_offset is not None:
            # option types below 127 are written to the builtin options field, all others to builtin options 2
            if opt_serializer.builtin_opt_type < 127:
                Operator.OperatorAddBuiltinOptionsType(builder, opt_serializer.builtin_opt_type)
                Operator.OperatorAddBuiltinOptions(builder, builtin_opt_offset)
            else:
                Operator.OperatorAddBuiltinOptions2Type(builder, opt_serializer.builtin_opt_type % 127)
                Operator.OperatorAddBuiltinOptions2(builder, builtin_opt_offset)
        if custom_opt_offset is not None:
            Operator.OperatorAddCustomOptions(builder, custom_opt_offset)
            Operator.OperatorAddCustomOptionsFormat(builder, opt_serializer.custom_opt_format)

        Operator.OperatorAddMutatingVariableInputs(builder, mutating_variable_inputs_offset)
        return Operator.OperatorEnd(builder)

    def serialise_subgraph(self, sg, name):
        """Write a SubGraph table for ``sg``, including all of its tensors and operators.

        This sets ``self.tensor_map_sg`` and ``self.buffer_map`` for the subgraph and appends the tensor map to
        ``self.tensor_map_all``. Virtual outputs are removed from ``sg`` in place.

        :param sg: the subgraph to write
        :param name: offset of the already written string that holds the subgraph name
        :return: offset of the written SubGraph table
        """
        builder = self.builder
        all_ops = []
        placeholder_ops = []

        for ps in sg.passes:
            for op in ps.ops:
                if op.type not in self.ops_to_ignore:
                    all_ops.append(op)
                elif op.type == Op.Placeholder:
                    placeholder_ops.append(op)

        # Make sure all original tensors are written back, special case for Ops
        # with connected subgraphs. Even though not all inputs are used,
        # the reference kernel expects all inputs to be in the tflite file.
        # Since we traverse the graph starting with all outputs they are
        # always added but if an input is not referenced it will not be added
        # to an op.
        tensor_set = set(sg.original_inputs)

        # Remove any virtual outputs since they are only used internally when
        # traversing the graph.
        for tens in sg.virtual_outputs:
            tens.ops[0].outputs = []
            if tens in sg.output_tensors:
                sg.output_tensors.remove(tens)

        # Add the tensors from all valid ops, as well as the tensors from placeholder ops
        # This allows us to serialise tensors which arent attached to any specific ops,
        # e.g. due to an empty graph containing no ops
        for op in all_ops + placeholder_ops:
            for tens in op.inputs + op.outputs + op.intermediates:
                if tens is not None:
                    tensor_set.add(tens)

        # order the tensors by name; the enumeration index breaks ties so that tensors are never compared
        ordered = sorted(enumerate(tensor_set), key=lambda entry: (entry[1].name, entry[0]))
        all_tensors = [tens for _, tens in ordered]

        scratch_tensors = [tens for tens in all_tensors if tens.purpose is TensorPurpose.Scratch]
        assert len(scratch_tensors) <= 1, "Multiple scratch tensors"
        scratch_tensor = scratch_tensors[0] if scratch_tensors else None

        self.tensor_map_sg = {tens: idx for idx, tens in enumerate(all_tensors)}
        self.buffer_map = self.assign_buffers_to_tensors(all_tensors, scratch_tensor)
        self.tensor_map_all.append(self.tensor_map_sg)

        tensors_offset = self.write_offset_vector([self.serialise_tensor(tens) for tens in all_tensors])

        # Make sure the input_tensors haven't been modified
        assert all(inp in sg.original_inputs for inp in sg.input_tensors)
        inputs = [self.tensor_map_sg[tens] for tens in sg.original_inputs if tens in self.tensor_map_sg]

        inputs_offset = self.write_int_vector(inputs)
        outputs_offset = self.write_int_vector(
            [self.tensor_map_sg[tens] for tens in sg.output_tensors if tens in self.tensor_map_sg]
        )

        operators_offset = self.write_offset_vector([self.serialise_operator(op) for op in all_ops])

        SubGraph.SubGraphStart(builder)
        SubGraph.SubGraphAddTensors(builder, tensors_offset)
        SubGraph.SubGraphAddInputs(builder, inputs_offset)
        SubGraph.SubGraphAddOutputs(builder, outputs_offset)

        SubGraph.SubGraphAddOperators(builder, operators_offset)
        SubGraph.SubGraphAddName(builder, name)

        return SubGraph.SubGraphEnd(builder)

    def write_aligned_bytes(self, buf):
        """Write ``buf`` as a byte vector whose storage is prepared with an alignment of 16.

        :param buf: a string (written as UTF-8) or anything that ``bytes()`` accepts
        :return: the value of the builder's ``EndVector()``
        """
        builder = self.builder
        # the vector is written by hand, so do the bookkeeping that the builder would normally do when it starts one
        builder.assertNotNested()
        builder.nested = True
        if isinstance(buf, str):
            data = bytes(buf, "utf-8")
        else:
            data = bytes(buf)
        length_bytes = UOffsetTFlags.py_type(len(data))
        builder.vectorNumElems = length_bytes
        builder.Prep(16, length_bytes)  # Reserve aligned storage
        builder.head = UOffsetTFlags.py_type(builder.Head() - length_bytes)  # Update FlatBuffer internal pointer
        builder.Bytes[builder.Head() : builder.Head() + length_bytes] = data  # Assign bytes to aligned area
        return builder.EndVector()

    def serialise_buffer(self, buf):
        """Write a Buffer table; the data field is only added if ``buf`` is not None.

        :return: offset of the written Buffer table
        """
        builder = self.builder
        # the data vector must be written before the table is started
        data = self.write_aligned_bytes(buf) if buf is not None else None
        Buffer.BufferStart(builder)
        if data is not None:
            Buffer.BufferAddData(builder, data)
        return Buffer.BufferEnd(builder)

    def serialise_metadata(self, metadata):
        """Write a Metadata table.

        :param metadata: tuple of (name, index of the buffer that holds the metadata)
        :return: offset of the written Metadata table
        """
        builder = self.builder
        name = builder.CreateString(metadata[0])

        Metadata.MetadataStart(builder)
        Metadata.MetadataAddName(builder, name)
        Metadata.MetadataAddBuffer(builder, metadata[1])

        return Metadata.MetadataEnd(builder)

    def serialise_model(self):
        """Write the Model table together with everything that it refers to.

        Note that this appends a "vela_version" entry and, unless an entry named b"OfflineMemoryAllocation" already
        exists, an "OfflineMemoryAllocation" entry to ``self.nng.metadata``.

        :return: offset of the written Model table
        """
        builder = self.builder
        operator_code_offset = self.write_offset_vector(
            [
                self.serialise_operator_code(idx, optype, code, version)
                for idx, (optype, code, version) in enumerate(self.operator_codes)
            ]
        )

        description = builder.CreateString(f"Vela {__version__} Optimised")
        self.nng.metadata.append(("vela_version", __version__))

        subgraph_offset = self.write_offset_vector(
            [self.serialise_subgraph(sg, builder.CreateString(sg.name)) for sg in self.subgraphs_to_write]
        )

        # Fill the metadata buffer
        version = np.int32(0)
        subgraph_idx = np.int32(len(self.subgraphs_to_write))

        nbr_tensors_all = np.sum([len(tensor_map_sg) for tensor_map_sg in self.tensor_map_all], dtype=np.int32)

        offlineAlloc = [version, subgraph_idx, nbr_tensors_all]

        if not any(name == b"OfflineMemoryAllocation" for name, _ in self.nng.metadata):
            for tensor_map_sg in self.tensor_map_all:
                nbr_tensors_sg = np.int32(len(tensor_map_sg))
                # An offset of -1 indicates that the tensor will be allocated online by Tensorflow Lite Micro
                offsets = [np.int32(-1)] * nbr_tensors_sg
                # Ensure that the order of the offsets match the order of the tensors
                for tens, idx in tensor_map_sg.items():
                    # Set offsets for tensor allocated in Tensor Arena or in the scratch_fast area
                    if tens.mem_type in (MemType.Scratch, MemType.Scratch_fast):
                        offsets[idx] = np.int32(tens.address) if tens.address is not None else np.int32(0)

                offlineAlloc += offsets

            self.nng.metadata.append(("OfflineMemoryAllocation", np.array(offlineAlloc)))

        # every metadata entry gets its own buffer which is placed after the tensor buffers
        metadata_list = []
        for name, buffer in self.nng.metadata:
            self.buffers_to_write.append(buffer)
            metadata_list.append((name, len(self.buffers_to_write) - 1))

        buffers_offset = self.write_offset_vector([self.serialise_buffer(buf) for buf in self.buffers_to_write])
        metadata_offset = self.write_offset_vector([self.serialise_metadata(metadata) for metadata in metadata_list])

        Model.ModelStart(builder)
        Model.ModelAddVersion(builder, tflite_version)
        Model.ModelAddOperatorCodes(builder, operator_code_offset)
        Model.ModelAddSubgraphs(builder, subgraph_offset)
        Model.ModelAddDescription(builder, description)
        Model.ModelAddBuffers(builder, buffers_offset)
        Model.ModelAddMetadata(builder, metadata_offset)
        return Model.ModelEnd(builder)

    def serialise(self):
        """Serialise the whole model, finish the flatbuffer with the file identifier and return the builder output."""
        model = self.serialise_model()

        self.builder.FinishWithFileIdentifier(model, tflite_file_identifier)

        return self.builder.Output()

    def write(self, filename):
        """Write the serialised model to the file ``filename``.

        The model is serialised on the first call and the result is kept in ``self.serialised_buf``, which is
        reused by later calls.

        :param filename: path of the file to write; an existing file is overwritten
        """
        serialised = getattr(self, "serialised_buf", None)
        if serialised is None:
            serialised = self.serialised_buf = self.serialise()

        with open(filename, "wb") as f:
            f.write(serialised)
533
534
def write_tflite(nng, filename):
    """Serialise the network graph ``nng`` to TensorFlow Lite format and write it to the file ``filename``."""
    serialised = TFLiteSerialiser(nng).serialise()

    with open(filename, "wb") as out_file:
        out_file.write(serialised)