# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Functions used to write to a TensorFlow Lite format file. Supports adding in file identifiers.
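#
# Minimal usage sketch (illustrative only; assumes a compiled network graph `nng` from the Vela pass
# pipeline and an example output path):
#
#     from .tflite_writer import write_tflite
#     write_tflite(nng, "output/network_vela.tflite")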
import flatbuffers
import flatbuffers.number_types as N
import numpy as np
from flatbuffers import encode
from flatbuffers.builder import UOffsetTFlags

from .errors import VelaError
from .nn_graph import PassPlacement
from .operation import Op
from .tensor import MemType
from .tensor import TensorPurpose
from .tflite import Buffer
from .tflite import Metadata
from .tflite import Model
from .tflite import Operator
from .tflite import OperatorCode
from .tflite import QuantizationParameters
from .tflite import SubGraph
from .tflite import Tensor
from .tflite_mapping import builtin_operator_inv_map
from .tflite_mapping import BuiltinOperator
from .tflite_mapping import datatype_inv_map


# The Python flatbuffers interface is missing a method to add in a file identifier, so it is patched in here:

tflite_version = 3
tflite_file_identifier = "TFL" + str(tflite_version)


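# The identifier is written as four bytes immediately below the root table offset, so in the finished
# buffer it occupies bytes 4..7, which is where TensorFlow Lite readers look for the "TFL3" marker.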
def FinishWithFileIdentifier(self, rootTable, fid):
    if fid is None or len(fid) != 4:
        raise VelaError("FileIdentifier must be 4 chars")

    flags = N.Uint8Flags
    prepSize = 4
    self.Prep(self.minalign, prepSize + len(fid))
    for i in range(3, -1, -1):
        self.head = self.head - flags.bytewidth
        encode.Write(flags.packer_type, self.Bytes, self.Head(), ord(fid[i]))

    return self.Finish(rootTable)


flatbuffers.Builder.FinishWithFileIdentifier = FinishWithFileIdentifier


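# Wrap a scalar value in a single-element list; anything that already has a length is passed through unchanged.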
def make_vector(v):
    try:
        len(v)
        return v
    except TypeError:
        return [v]


class TFLiteSerialiser:
    def __init__(self, nng):
        self.builder = flatbuffers.Builder(0)
        self.nng = nng

        self.scratch_buf_id = 0  # Always assign scratch to buffer 0
        self.scratch_fast_buf_id = 1  # Always assign scratch_fast to buffer 1
        self.buffers_to_write = []  # Start with an empty list of buffers

        self.ops_to_ignore = (Op.Const, Op.Placeholder, Op.SubgraphInput)

        self.tensors_to_reshape = {}

        self.subgraphs_to_write = [sg for sg in self.nng.subgraphs if sg.placement == PassPlacement.Cpu]

        all_ops = []
        for sg in self.subgraphs_to_write:
            for ps in sg.passes:
                for op in ps.ops:
                    if op.type not in self.ops_to_ignore:
                        all_ops.append(op)
                        if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                            # If values are None, the op has non-constant weights
                            if op.inputs[1].values is not None:
                                self.tensors_to_reshape[op.inputs[1]] = (3, 0, 1, 2)
                        if op.type == Op.FullyConnected:
                            # If values are None, the op has non-constant weights
                            if op.inputs[1].values is not None:
                                self.tensors_to_reshape[op.inputs[1]] = (1, 0)

        # list of tuple(Op, string); the custom code is only used for 3rd party custom operators
        self.operator_codes = sorted(set((op.type, op.attrs.get("custom_code", "")) for op in all_ops))
        self.operator_code_map = {}

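    # The write_*_vector helpers below build FlatBuffer vectors; elements are prepended in reverse order
    # because the FlatBuffers builder writes its buffer back to front.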
    def write_byte_vector(self, v, alignment=1):
        builder = self.builder
        builder.StartVector(1, len(v), alignment)
        for e in v[::-1]:
            builder.PrependByte(e)
        return builder.EndVector(len(v))

    def write_int_vector(self, v):
        builder = self.builder
        builder.StartVector(4, len(v), 4)
        for e in v[::-1]:
            builder.PrependInt32(e)
        return builder.EndVector(len(v))

    def write_long_vector(self, v):
        builder = self.builder
        builder.StartVector(8, len(v), 8)
        for e in v[::-1]:
            builder.PrependInt64(e)
        return builder.EndVector(len(v))

    def write_float_vector(self, v):
        builder = self.builder
        builder.StartVector(4, len(v), 4)
        for e in v[::-1]:
            builder.PrependFloat32(e)
        return builder.EndVector(len(v))

    def write_offset_vector(self, v):
        builder = self.builder
        builder.StartVector(4, len(v), 4)
        for e in v[::-1]:
            builder.PrependUOffsetTRelative(e)
        return builder.EndVector(len(v))

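    # Map each tensor to a TensorFlow Lite buffer index: tensors in the TensorArena share buffer 0 (scratch),
    # Scratch_fast tensors share buffer 1, and every remaining tensor gets its own buffer starting at index 2.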
    def assign_buffers_to_tensors(self, tensors, scratch_tensor):
        if scratch_tensor is not None:
            scratch_tensor_mem_area = scratch_tensor.mem_area
        else:
            scratch_tensor_mem_area = None  # all tensors are initialised to MemArea.Unknown

        buffer_map = {}

        buf_idx = 2

        for tens in tensors:
            # Set buffer ids depending on allocation
            if tens.is_allocated_in_tensor_arena(scratch_tensor_mem_area):
                buffer_map[tens] = self.scratch_buf_id
            elif tens.mem_type == MemType.Scratch_fast:
                # For Scratch_fast when not co-allocated with scratch in the TensorArena:
                buffer_map[tens] = self.scratch_fast_buf_id
            else:
                buffer_map[tens] = buf_idx
                buf_idx += 1

        # Initialise buffers_to_write to a length equal to the number of buffers so that
        # they can be appended at the correct index during tensor serialisation
        self.buffers_to_write = [None] * buf_idx

        return buffer_map

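    # Serialise a single OperatorCode table. Third-party custom operators are additionally keyed by their
    # custom_code string, since several distinct custom operators can share the same builtin CUSTOM code.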
    def serialise_operator_code(self, idx, op_type, custom_code):
        builder = self.builder
        custom_code_offset = None
        if op_type == Op.Custom:
            tf_code, opt_serializer = builtin_operator_inv_map[op_type]
            custom_code_offset = builder.CreateString(custom_code)
        else:
            assert (
                op_type in builtin_operator_inv_map
            ), "Vela does not contain a mapping to serialise {} operator to a TensorFlow Lite operator".format(op_type)
            tf_code, opt_serializer = builtin_operator_inv_map[op_type]

        if op_type == Op.CustomNpuOp:
            assert (
                tf_code == BuiltinOperator.CUSTOM
            ), "Vela only supports serialising NpuOp operators as TensorFlow Lite Custom operators"
            custom_code_offset = builder.CreateString("ethos-u")

        # there can be multiple different types of 3rd party custom operators (i.e. non-"ethos-u" ones). therefore we
        # need to add an extra level of indirection to this particular entry in the operator_code_map to allow for the
        # correct lookup later on
        if op_type == Op.Custom:
            if op_type not in self.operator_code_map:
                self.operator_code_map[op_type] = {}
            self.operator_code_map[op_type][custom_code] = (idx, tf_code, opt_serializer)
        else:
            self.operator_code_map[op_type] = (idx, tf_code, opt_serializer)

        OperatorCode.OperatorCodeStart(builder)
        OperatorCode.OperatorCodeAddDeprecatedBuiltinCode(builder, tf_code if tf_code < 127 else 127)
        OperatorCode.OperatorCodeAddBuiltinCode(builder, tf_code)
        if custom_code_offset is not None:
            OperatorCode.OperatorCodeAddCustomCode(builder, custom_code_offset)

        return OperatorCode.OperatorCodeEnd(builder)

    def serialise_quantization_parameters(self, quant):
        builder = self.builder

        qp = None
        min = None
        max = None
        scale = None
        zero_point = None
        if quant is not None:
            if quant.min is not None:
                min = self.write_float_vector(make_vector(quant.min))
            if quant.max is not None:
                max = self.write_float_vector(make_vector(quant.max))
            if quant.scale_f32 is not None:
                scale = self.write_float_vector(make_vector(quant.scale_f32))
            if quant.zero_point is not None:
                zero_point = self.write_long_vector(make_vector(quant.zero_point))

            QuantizationParameters.QuantizationParametersStart(builder)
            if min is not None:
                QuantizationParameters.QuantizationParametersAddMin(builder, min)
            if max is not None:
                QuantizationParameters.QuantizationParametersAddMax(builder, max)
            if scale is not None:
                QuantizationParameters.QuantizationParametersAddScale(builder, scale)
            if zero_point is not None:
                QuantizationParameters.QuantizationParametersAddZeroPoint(builder, zero_point)
            qp = QuantizationParameters.QuantizationParametersEnd(builder)

        return qp

    def serialise_tensor(self, tens):
        builder = self.builder
        tens_shape = tens.shape
        values = tens.quant_values
        if values is None:
            values = tens.values

        if values is None:
            values = np.empty(shape=(0), dtype=np.uint8)

        if tens in self.tensors_to_reshape:
            reorder = self.tensors_to_reshape[tens]
            tens_shape = [tens_shape[idx] for idx in reorder]
            values = values.transpose(reorder)

        buf_id = self.buffer_map[tens]
        self.buffers_to_write[buf_id] = values.flatten().view(np.uint8)

        shape = self.write_int_vector(tens_shape)

        name = builder.CreateString(tens.name)
        quant = self.serialise_quantization_parameters(tens.quantization)

        Tensor.TensorStart(builder)
        Tensor.TensorAddShape(builder, shape)
        Tensor.TensorAddType(builder, datatype_inv_map[tens.dtype])
        # All tensors must have a valid backing buffer, even if it is empty.
        # Empty buffers should be kept unique for TensorFlow Lite Micro
        Tensor.TensorAddBuffer(builder, buf_id)
        Tensor.TensorAddName(builder, name)
        if quant is not None:
            Tensor.TensorAddQuantization(builder, quant)
        Tensor.TensorAddIsVariable(builder, tens.is_variable)

        res = Tensor.TensorEnd(builder)
        return res

    def serialise_operator(self, op):
        builder = self.builder

        inputs_offset = self.write_int_vector(
            [self.tensor_map[tens] if tens in self.tensor_map else -1 for tens in op.inputs]
        )
        outputs_offset = self.write_int_vector(
            [self.tensor_map[tens] for tens in op.outputs if tens in self.tensor_map]
        )
        intermediates_offset = self.write_int_vector(
            [self.tensor_map[tens] for tens in op.intermediates if tens in self.tensor_map]
        )

        if op.type == Op.Custom:
            op_idx, tflop, opt_serializer = self.operator_code_map[op.type][op.attrs.get("custom_code", "")]
        else:
            op_idx, tflop, opt_serializer = self.operator_code_map[op.type]

        builtin_opt_offset = None
        custom_opt_offset = None
        if opt_serializer is not None:
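            # Translate Vela attribute names into the TensorFlow Lite option field names expected by the
            # option serialiser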
            attrs = dict(op.attrs)
            if "strides" in attrs:
                attrs["stride_h"] = attrs["strides"][1]
                attrs["stride_w"] = attrs["strides"][2]
            if "ksize" in attrs:
                attrs["filter_height"] = attrs["ksize"][1]
                attrs["filter_width"] = attrs["ksize"][2]
            if "dilation" in attrs:
                attrs["dilation_h_factor"] = attrs["dilation"][1]
                attrs["dilation_w_factor"] = attrs["dilation"][2]
            if "channel_multiplier" in attrs:
                attrs["depth_multiplier"] = attrs["channel_multiplier"]
            attrs["fused_activation_function"] = op.activation.op_type if op.activation is not None else None

            builtin_opt_offset, custom_opt_offset = opt_serializer.serialize(builder, attrs)

        mutating_variable_inputs_offset = self.write_byte_vector([])
        Operator.OperatorStart(builder)
        Operator.OperatorAddOpcodeIndex(builder, op_idx)
        Operator.OperatorAddInputs(builder, inputs_offset)
        Operator.OperatorAddOutputs(builder, outputs_offset)
        Operator.OperatorAddIntermediates(builder, intermediates_offset)

        if builtin_opt_offset is not None:
            Operator.OperatorAddBuiltinOptionsType(builder, opt_serializer.builtin_opt_type)
            Operator.OperatorAddBuiltinOptions(builder, builtin_opt_offset)
        if custom_opt_offset is not None:
            Operator.OperatorAddCustomOptions(builder, custom_opt_offset)
            Operator.OperatorAddCustomOptionsFormat(builder, opt_serializer.custom_opt_format)

        Operator.OperatorAddMutatingVariableInputs(builder, mutating_variable_inputs_offset)
        return Operator.OperatorEnd(builder)

    def serialise_subgraph(self, sg):
        builder = self.builder
        tensor_set = set()
        all_ops = []
        placeholder_ops = []

        for ps in sg.passes:
            for op in ps.ops:
                if op.type not in self.ops_to_ignore:
                    all_ops.append(op)
                elif op.type == Op.Placeholder:
                    placeholder_ops.append(op)

        # Add the tensors from all valid ops, as well as the tensors from placeholder ops
        # This allows us to serialise tensors which aren't attached to any specific ops,
        # e.g. due to an empty graph containing no ops
        for op in all_ops + placeholder_ops:
            for tens in op.inputs + op.outputs + op.intermediates:
                if tens is not None:
                    tensor_set.add(tens)

        all_tensors = [tens for nm, idx, tens in sorted((tens.name, idx, tens) for idx, tens in enumerate(tensor_set))]

        scratch_tensors = [tens for tens in all_tensors if tens.purpose is TensorPurpose.Scratch]

        if len(scratch_tensors) == 0:
            scratch_tensor = None
        else:
            assert len(scratch_tensors) == 1, "Multiple scratch tensors"
            scratch_tensor = scratch_tensors[0]

        self.tensor_map = {tens: idx for idx, tens in enumerate(all_tensors)}
        self.buffer_map = self.assign_buffers_to_tensors(all_tensors, scratch_tensor)

        tensors_offset = self.write_offset_vector([self.serialise_tensor(tens) for tens in all_tensors])

        # Make sure the input_tensors haven't been modified
        assert all(inp in sg.original_inputs for inp in sg.input_tensors)
        inputs = [self.tensor_map[tens] for tens in sg.original_inputs if tens in self.tensor_map]

        inputs_offset = self.write_int_vector(inputs)
        outputs_offset = self.write_int_vector(
            [self.tensor_map[tens] for tens in sg.output_tensors if tens in self.tensor_map]
        )

        operators_offset = self.write_offset_vector([self.serialise_operator(op) for op in all_ops])

        SubGraph.SubGraphStart(builder)
        SubGraph.SubGraphAddTensors(builder, tensors_offset)
        SubGraph.SubGraphAddInputs(builder, inputs_offset)
        SubGraph.SubGraphAddOutputs(builder, outputs_offset)

        SubGraph.SubGraphAddOperators(builder, operators_offset)

        return SubGraph.SubGraphEnd(builder)

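    # Write raw bytes into the FlatBuffer 16-byte aligned by moving the builder head directly; this keeps
    # tensor data in the output file aligned without going through the per-element vector API.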
    def write_aligned_bytes(self, buf):
        builder = self.builder
        builder.nested = True
        data = bytes(buf)
        length_bytes = UOffsetTFlags.py_type(len(data))
        builder.Prep(16, length_bytes)  # Reserve aligned storage
        builder.head = UOffsetTFlags.py_type(builder.Head() - length_bytes)  # Update FlatBuffer internal pointer
        builder.Bytes[builder.Head() : builder.Head() + length_bytes] = data  # Assign bytes to aligned area
        return builder.EndVector(length_bytes)

    def serialise_buffer(self, buf):
        builder = self.builder
        data = None
        if buf is not None:
            data = self.write_aligned_bytes(buf)
        Buffer.BufferStart(builder)
        if data is not None:
            Buffer.BufferAddData(builder, data)
        return Buffer.BufferEnd(builder)

    def serialise_metadata(self, metadata):
        builder = self.builder
        name = builder.CreateString(metadata[0])

        Metadata.MetadataStart(builder)
        Metadata.MetadataAddName(builder, name)
        Metadata.MetadataAddBuffer(builder, metadata[1])

        return Metadata.MetadataEnd(builder)

    def serialise_model(self):
        builder = self.builder
        operator_code_offset = self.write_offset_vector(
            [self.serialise_operator_code(idx, optype, code) for idx, (optype, code) in enumerate(self.operator_codes)]
        )

        description = builder.CreateString("Vela Optimised")

        subgraph_offset = self.write_offset_vector([self.serialise_subgraph(sg) for sg in self.subgraphs_to_write])

        # Fill the metadata buffer
        version = np.int32(0)
        subgraph_idx = np.int32(len(self.subgraphs_to_write))  # Only 1 supported currently
        nbr_tensors = np.int32(len(self.tensor_map))

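        # The OfflineMemoryAllocation metadata written below is a flat int32 array laid out as
        # [version, subgraph_idx, nbr_tensors, offset_0, ..., offset_(nbr_tensors - 1)]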
        if not any([name == b"OfflineMemoryAllocation" for name, _ in self.nng.metadata]):
            # An offset of -1 indicates that the tensor will be allocated online by TensorFlow Lite Micro
            offsets = [np.int32(-1)] * nbr_tensors

            # Ensure that the order of the offsets matches the order of the tensors
            for tens, idx in self.tensor_map.items():
                # Set offsets for tensors allocated in the Tensor Arena or in the scratch_fast area
                if tens.mem_type in (MemType.Scratch, MemType.Scratch_fast):
                    offsets[idx] = np.int32(tens.address) if tens.address is not None else np.int32(0)

            self.nng.metadata.append(
                ("OfflineMemoryAllocation", np.array([version, subgraph_idx, nbr_tensors] + offsets))
            )

        metadata_list = []
        for name, buffer in self.nng.metadata:
            self.buffers_to_write.append(buffer)
            metadata_list.append((name, len(self.buffers_to_write) - 1))

        buffers_offset = self.write_offset_vector([self.serialise_buffer(buf) for buf in self.buffers_to_write])
        metadata_offset = self.write_offset_vector([self.serialise_metadata(metadata) for metadata in metadata_list])

        Model.ModelStart(builder)
        Model.ModelAddVersion(builder, tflite_version)
        Model.ModelAddOperatorCodes(builder, operator_code_offset)
        Model.ModelAddSubgraphs(builder, subgraph_offset)
        Model.ModelAddDescription(builder, description)
        Model.ModelAddBuffers(builder, buffers_offset)
        Model.ModelAddMetadata(builder, metadata_offset)
        return Model.ModelEnd(builder)

    def serialise(self):

        model = self.serialise_model()

        self.builder.FinishWithFileIdentifier(model, tflite_file_identifier)

        return self.builder.Output()

    def write(self, filename):
        # Serialise on demand and write to the given filename; neither self.filename nor a cached
        # serialised buffer is ever set on this class
        with open(filename, "wb") as f:
            f.write(self.serialise())


def write_tflite(nng, filename):
    writer = TFLiteSerialiser(nng)
    buf = writer.serialise()

    with open(filename, "wb") as f:
        f.write(buf)