blob: 687b8876ac03014a01b3a054c7e97a6b69633544 [file] [log] [blame]
Fredrik Svedberg8d0f4892021-02-16 21:59:50 +01001# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
Tim Hall79d07d22020-04-27 18:20:16 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Tim Hall79d07d22020-04-27 18:20:16 +010016# Description:
17# Functions used to write to a TensorFlow Lite format file. Supports adding in file identifiers.
Tim Hall79d07d22020-04-27 18:20:16 +010018import flatbuffers
Diego Russoe8a10452020-04-21 17:39:10 +010019import flatbuffers.number_types as N
20import numpy as np
21from flatbuffers import encode
Diego Russoea6111a2020-04-14 18:41:58 +010022from flatbuffers.builder import UOffsetTFlags
23
Michael McGeagh7a6f8432020-12-02 15:29:22 +000024from .errors import VelaError
Diego Russoe8a10452020-04-21 17:39:10 +010025from .nn_graph import PassPlacement
Louis Verhaardaee5d752020-09-30 09:01:52 +020026from .operation import Op
Patrik Gustavssoneca2e952020-05-27 09:15:11 +020027from .tensor import MemType
Tim Hall79d07d22020-04-27 18:20:16 +010028from .tflite import Buffer
29from .tflite import Metadata
Diego Russoe8a10452020-04-21 17:39:10 +010030from .tflite import Model
31from .tflite import Operator
32from .tflite import OperatorCode
33from .tflite import QuantizationParameters
34from .tflite import SubGraph
35from .tflite import Tensor
36from .tflite_mapping import builtin_operator_inv_map
37from .tflite_mapping import BuiltinOperator
Diego Russoe8a10452020-04-21 17:39:10 +010038from .tflite_mapping import datatype_inv_map
39
# TensorFlow Lite schema version stamped into the written model, and the matching
# 4-character flatbuffer file identifier derived from it ("TFL3").
# The Python flatbuffers Builder has no method for finishing a buffer with a file
# identifier, so FinishWithFileIdentifier is defined and patched onto it below.
tflite_version = 3
tflite_file_identifier = "TFL{}".format(tflite_version)
45
def FinishWithFileIdentifier(self, rootTable, fid):
    """Finish the flatbuffer with rootTable, placing a 4-character file identifier before the root offset.

    Intended to be patched onto flatbuffers.Builder (see the assignment after this function).

    :param rootTable: offset of the root table passed on to Builder.Finish.
    :param fid: 4-character identifier string, e.g. "TFL3".
    :raises VelaError: if fid is None or is not exactly 4 characters long.
    :return: the value returned by Builder.Finish(rootTable).
    """
    if fid is None or len(fid) != 4:
        raise VelaError("FileIdentifier must be 4 chars")

    byte_flags = N.Uint8Flags
    root_offset_size = 4
    # Reserve space for the root offset plus the identifier bytes
    self.Prep(self.minalign, root_offset_size + len(fid))
    # The head pointer moves downwards on every write, so emit the characters last-to-first
    # to have them appear in their natural order in the finished buffer
    for char in reversed(fid):
        self.head = self.head - byte_flags.bytewidth
        encode.Write(byte_flags.packer_type, self.Bytes, self.Head(), ord(char))

    return self.Finish(rootTable)


flatbuffers.Builder.FinishWithFileIdentifier = FinishWithFileIdentifier
61
62
def make_vector(v):
    """Return v unchanged if len() works on it, otherwise wrap it in a one-element list.

    Scalars (including numpy scalars and 0-d arrays, for which len() raises TypeError)
    therefore become single-element lists.
    """
    try:
        len(v)
    except TypeError:
        return [v]
    return v
69
70
class TFLiteSerialiser:
    """Serialise the CPU-placed subgraphs of a Vela network graph (nng) to a TensorFlow Lite flatbuffer.

    Typical use is ``TFLiteSerialiser(nng).serialise()``, which returns the finished flatbuffer bytes.
    Buffer 0 is always used for tensors allocated in the scratch tensor arena and buffer 1 for
    Scratch_fast tensors that are not co-allocated with scratch; every other tensor gets its own
    buffer numbered from 2 upwards (see assign_buffers_to_tensors).
    """

    def __init__(self, nng):
        self.builder = flatbuffers.Builder(0)
        self.nng = nng

        self.scratch_buf_id = 0  # Always assign scratch to buffer 0
        self.scratch_fast_buf_id = 1  # Always assign scratch_fast to buffer 1
        self.buffers_to_write = []  # have an empty array there

        # Filled in per subgraph by serialise_subgraph(); initialised here so that serialise_model()
        # does not fail on a missing attribute when there is no CPU subgraph to write
        self.tensor_map = {}
        self.buffer_map = {}
        # Result of serialise(), cached so that write() can reuse it
        self.serialised_buf = None

        self.ops_to_ignore = (Op.Const, Op.Placeholder, Op.SubgraphInput)

        # Constant weight tensors -> axis permutation applied to their shape and values when serialised
        self.tensors_to_reshape = {}

        self.subgraphs_to_write = [sg for sg in self.nng.subgraphs if sg.placement == PassPlacement.Cpu]

        all_ops = []
        for sg in self.subgraphs_to_write:
            for ps in sg.passes:
                for op in ps.ops:
                    if op.type not in self.ops_to_ignore:
                        all_ops.append(op)
                    # If values are None the op has non-constant weights, which are not reordered
                    if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                        if op.inputs[1].values is not None:
                            self.tensors_to_reshape[op.inputs[1]] = (3, 0, 1, 2)
                    if op.type == Op.FullyConnected:
                        if op.inputs[1].values is not None:
                            self.tensors_to_reshape[op.inputs[1]] = (1, 0)

        # list of tuple(Op, string); the custom code is only used for 3rd party custom operators
        self.operator_codes = sorted(set((op.type, op.attrs.get("custom_code", "")) for op in all_ops))
        self.operator_code_map = {}

    def _write_vector(self, v, elem_size, alignment, prepend):
        """Write the sequence v as a flatbuffer vector and return its offset.

        Elements are prepended in reverse order so that they end up in their original order.
        ``prepend`` is the builder method that writes a single element.
        """
        builder = self.builder
        builder.StartVector(elem_size, len(v), alignment)
        for e in v[::-1]:
            prepend(e)
        return builder.EndVector(len(v))

    def write_byte_vector(self, v, alignment=1):
        """Write v as a vector of bytes with the given alignment; return its offset."""
        return self._write_vector(v, 1, alignment, self.builder.PrependByte)

    def write_int_vector(self, v):
        """Write v as a vector of int32; return its offset."""
        return self._write_vector(v, 4, 4, self.builder.PrependInt32)

    def write_long_vector(self, v):
        """Write v as a vector of int64; return its offset."""
        return self._write_vector(v, 8, 8, self.builder.PrependInt64)

    def write_float_vector(self, v):
        """Write v as a vector of float32; return its offset."""
        return self._write_vector(v, 4, 4, self.builder.PrependFloat32)

    def write_offset_vector(self, v):
        """Write v, a sequence of previously created table offsets, as a vector of offsets."""
        return self._write_vector(v, 4, 4, self.builder.PrependUOffsetTRelative)

    def assign_buffers_to_tensors(self, tensors, scratch_tensor):
        """Map every tensor to the index of the buffer that will back it.

        Tensors allocated in the scratch tensor arena share buffer 0; Scratch_fast tensors not in
        the arena share buffer 1; every other tensor gets a unique index starting at 2.
        Also resets self.buffers_to_write to a list of None with one slot per buffer index.

        :param tensors: tensors of the subgraph being serialised.
        :param scratch_tensor: the subgraph's scratch tensor, or None if it has none.
        :return: dict mapping tensor -> buffer index.
        """
        if scratch_tensor is not None:
            scratch_tensor_mem_area = scratch_tensor.mem_area
        else:
            scratch_tensor_mem_area = None  # all tensors are initialised to MemArea.Unknown

        buffer_map = {}

        buf_idx = 2

        for tens in tensors:
            # Set buffer ids depending on allocation
            if tens.is_allocated_in_tensor_arena(scratch_tensor_mem_area):
                buffer_map[tens] = self.scratch_buf_id
            elif tens.mem_type == MemType.Scratch_fast:
                # For Scratch_fast when not co-allocated with scratch in the TensorArena:
                buffer_map[tens] = self.scratch_fast_buf_id
            else:
                buffer_map[tens] = buf_idx
                buf_idx += 1

        # Initialize buffers_to_write to a length equal to number of buffers so
        # they can be appended at the correct index during tensor serialization
        self.buffers_to_write = [None] * (buf_idx)

        return buffer_map

    def serialise_operator_code(self, idx, op_type, custom_code):
        """Serialise one OperatorCode table and register it in self.operator_code_map.

        :param idx: position of this operator code in the model's operator_codes vector.
        :param op_type: Vela Op type.
        :param custom_code: custom code string; only used for 3rd party custom operators (Op.Custom).
        :return: offset of the OperatorCode table.
        """
        builder = self.builder
        assert (
            op_type in builtin_operator_inv_map
        ), "Vela does not contain a mapping to serialise {} operator to a TensorFlow Lite operator".format(op_type)
        tf_code, opt_serializer = builtin_operator_inv_map[op_type]

        custom_code_offset = None
        if op_type == Op.Custom:
            custom_code_offset = builder.CreateString(custom_code)
        elif op_type == Op.CustomNpuOp:
            assert (
                tf_code == BuiltinOperator.CUSTOM
            ), "Vela only supports serialising NpuOp operators as TensorFlow Lite Custom operators"
            custom_code_offset = builder.CreateString("ethos-u")

        # there can be multiple different types of 3rd party custom operators (i.e. non-"ethos-u" ones). therefore we
        # need to add an extra level of indirection to this particular entry in the operator_code_map to allow for the
        # correct lookup later on
        if op_type == Op.Custom:
            self.operator_code_map.setdefault(op_type, {})[custom_code] = (idx, tf_code, opt_serializer)
        else:
            self.operator_code_map[op_type] = (idx, tf_code, opt_serializer)

        OperatorCode.OperatorCodeStart(builder)
        # Codes of 127 and above are clamped to 127 in the deprecated field; the full code goes in builtin_code
        OperatorCode.OperatorCodeAddDeprecatedBuiltinCode(builder, min(tf_code, 127))
        OperatorCode.OperatorCodeAddBuiltinCode(builder, tf_code)
        if custom_code_offset is not None:
            OperatorCode.OperatorCodeAddCustomCode(builder, custom_code_offset)

        return OperatorCode.OperatorCodeEnd(builder)

    def serialise_quantization_parameters(self, quant):
        """Serialise quant as a QuantizationParameters table.

        Only the fields (min, max, scale_f32, zero_point) that are not None are written; scalars are
        wrapped into one-element vectors.

        :return: offset of the table, or None when quant is None.
        """
        builder = self.builder

        qp = None
        min_offset = None
        max_offset = None
        scale = None
        zero_point = None
        if quant is not None:
            if quant.min is not None:
                min_offset = self.write_float_vector(make_vector(quant.min))
            if quant.max is not None:
                max_offset = self.write_float_vector(make_vector(quant.max))
            if quant.scale_f32 is not None:
                scale = self.write_float_vector(make_vector(quant.scale_f32))
            if quant.zero_point is not None:
                zero_point = self.write_long_vector(make_vector(quant.zero_point))

            QuantizationParameters.QuantizationParametersStart(builder)
            if min_offset is not None:
                QuantizationParameters.QuantizationParametersAddMin(builder, min_offset)
            if max_offset is not None:
                QuantizationParameters.QuantizationParametersAddMax(builder, max_offset)
            if scale is not None:
                QuantizationParameters.QuantizationParametersAddScale(builder, scale)
            if zero_point is not None:
                QuantizationParameters.QuantizationParametersAddZeroPoint(builder, zero_point)
            qp = QuantizationParameters.QuantizationParametersEnd(builder)

        return qp

    def serialise_tensor(self, tens):
        """Serialise tens as a Tensor table and stage its data for writing.

        The tensor's quant_values (falling back to values, then to an empty uint8 array) are
        permuted if the tensor is in tensors_to_reshape, flattened, viewed as uint8 and stored in
        self.buffers_to_write at the tensor's buffer index.

        :return: offset of the Tensor table.
        """
        builder = self.builder
        tens_shape = tens.shape
        values = tens.quant_values
        if values is None:
            values = tens.values

        if values is None:
            values = np.empty(shape=(0), dtype=np.uint8)

        if tens in self.tensors_to_reshape:
            reorder = self.tensors_to_reshape[tens]
            tens_shape = [tens_shape[idx] for idx in reorder]
            values = values.transpose(reorder)

        buf_id = self.buffer_map[tens]
        self.buffers_to_write[buf_id] = values.flatten().view(np.uint8)

        shape = self.write_int_vector(tens_shape)

        name = builder.CreateString(tens.name)
        quant = self.serialise_quantization_parameters(tens.quantization)

        Tensor.TensorStart(builder)
        Tensor.TensorAddShape(builder, shape)
        Tensor.TensorAddType(builder, datatype_inv_map[tens.dtype])
        # All tensors must have a valid backing buffer, even if it is empty.
        # Empty buffers should be kept unique for TensorFlow Lite Micro
        Tensor.TensorAddBuffer(builder, buf_id)
        Tensor.TensorAddName(builder, name)
        if quant is not None:
            Tensor.TensorAddQuantization(builder, quant)
        Tensor.TensorAddIsVariable(builder, tens.is_variable)

        res = Tensor.TensorEnd(builder)
        return res

    def serialise_operator(self, op):
        """Serialise op as an Operator table using the current subgraph's tensor_map.

        Inputs that are not in tensor_map (e.g. None for omitted optional inputs) are written as -1;
        outputs and intermediates not in tensor_map are left out.

        :return: offset of the Operator table.
        """
        builder = self.builder

        inputs_offset = self.write_int_vector(
            [self.tensor_map[tens] if tens in self.tensor_map else -1 for tens in op.inputs]
        )
        outputs_offset = self.write_int_vector(
            [self.tensor_map[tens] for tens in op.outputs if tens in self.tensor_map]
        )
        intermediates_offset = self.write_int_vector(
            [self.tensor_map[tens] for tens in op.intermediates if tens in self.tensor_map]
        )

        if op.type == Op.Custom:
            op_idx, tflop, opt_serializer = self.operator_code_map[op.type][op.attrs.get("custom_code", "")]
        else:
            op_idx, tflop, opt_serializer = self.operator_code_map[op.type]

        builtin_opt_offset = None
        custom_opt_offset = None
        if opt_serializer is not None:
            # Translate Vela's attribute names/layouts into the option names used by the serializer
            attrs = dict(op.attrs)
            if "strides" in attrs:
                attrs["stride_h"] = attrs["strides"][1]
                attrs["stride_w"] = attrs["strides"][2]
            if "ksize" in attrs:
                attrs["filter_height"] = attrs["ksize"][1]
                attrs["filter_width"] = attrs["ksize"][2]
            if "dilation" in attrs:
                attrs["dilation_h_factor"] = attrs["dilation"][1]
                attrs["dilation_w_factor"] = attrs["dilation"][2]
            if "channel_multiplier" in attrs:
                attrs["depth_multiplier"] = attrs["channel_multiplier"]
            attrs["fused_activation_function"] = op.activation.op_type if op.activation is not None else None

            builtin_opt_offset, custom_opt_offset = opt_serializer.serialize(builder, attrs)

        mutating_variable_inputs_offset = self.write_byte_vector([])
        Operator.OperatorStart(builder)
        Operator.OperatorAddOpcodeIndex(builder, op_idx)
        Operator.OperatorAddInputs(builder, inputs_offset)
        Operator.OperatorAddOutputs(builder, outputs_offset)
        Operator.OperatorAddIntermediates(builder, intermediates_offset)

        if builtin_opt_offset is not None:
            Operator.OperatorAddBuiltinOptionsType(builder, opt_serializer.builtin_opt_type)
            Operator.OperatorAddBuiltinOptions(builder, builtin_opt_offset)
        if custom_opt_offset is not None:
            Operator.OperatorAddCustomOptions(builder, custom_opt_offset)
            Operator.OperatorAddCustomOptionsFormat(builder, opt_serializer.custom_opt_format)

        Operator.OperatorAddMutatingVariableInputs(builder, mutating_variable_inputs_offset)
        return Operator.OperatorEnd(builder)

    def serialise_subgraph(self, sg):
        """Serialise sg as a SubGraph table.

        Sets self.tensor_map (tensor -> index, ordered by tensor name) and self.buffer_map for this
        subgraph before writing its tensors, inputs, outputs and operators.

        :return: offset of the SubGraph table.
        """
        builder = self.builder
        tensor_set = set()
        all_ops = []
        placeholder_ops = []

        for ps in sg.passes:
            for op in ps.ops:
                if op.type not in self.ops_to_ignore:
                    all_ops.append(op)
                elif op.type == Op.Placeholder:
                    placeholder_ops.append(op)

        # Add the tensors from all valid ops, as well as the tensors from placeholder ops
        # This allows us to serialise tensors which arent attached to any specific ops,
        # e.g. due to an empty graph containing no ops
        for op in all_ops + placeholder_ops:
            for tens in op.inputs + op.outputs + op.intermediates:
                if tens is not None:
                    tensor_set.add(tens)

        # Sort by name; the enumeration index breaks ties so tensors themselves are never compared
        all_tensors = [
            tens for _name, _idx, tens in sorted((tens.name, idx, tens) for idx, tens in enumerate(tensor_set))
        ]

        scratch_tensors = [tens for tens in all_tensors if tens.name.endswith("scratch")]

        if len(scratch_tensors) == 0:
            scratch_tensor = None
        else:
            assert len(scratch_tensors) == 1, "Multiple scratch tensors"
            scratch_tensor = scratch_tensors[0]

        self.tensor_map = {tens: idx for idx, tens in enumerate(all_tensors)}
        self.buffer_map = self.assign_buffers_to_tensors(all_tensors, scratch_tensor)

        tensors_offset = self.write_offset_vector([self.serialise_tensor(tens) for tens in all_tensors])

        # Make sure the input_tensors haven't been modified
        assert all(inp in sg.original_inputs for inp in sg.input_tensors)
        inputs = [self.tensor_map[tens] for tens in sg.original_inputs if tens in self.tensor_map]

        inputs_offset = self.write_int_vector(inputs)
        outputs_offset = self.write_int_vector(
            [self.tensor_map[tens] for tens in sg.output_tensors if tens in self.tensor_map]
        )

        operators_offset = self.write_offset_vector([self.serialise_operator(op) for op in all_ops])

        SubGraph.SubGraphStart(builder)
        SubGraph.SubGraphAddTensors(builder, tensors_offset)
        SubGraph.SubGraphAddInputs(builder, inputs_offset)
        SubGraph.SubGraphAddOutputs(builder, outputs_offset)

        SubGraph.SubGraphAddOperators(builder, operators_offset)

        return SubGraph.SubGraphEnd(builder)

    def write_aligned_bytes(self, buf):
        """Copy buf into the builder as a byte vector reserved with 16-byte alignment; return its offset."""
        builder = self.builder
        builder.nested = True
        data = bytes(buf)
        length_bytes = UOffsetTFlags.py_type(len(data))
        builder.Prep(16, length_bytes)  # Reserve aligned storage
        builder.head = UOffsetTFlags.py_type(builder.Head() - length_bytes)  # Update FlatBuffer internal pointer
        builder.Bytes[builder.Head() : builder.Head() + length_bytes] = data  # Assign bytes to aligned area
        return builder.EndVector(length_bytes)

    def serialise_buffer(self, buf):
        """Serialise buf as a Buffer table; a None buf gives a Buffer with no data."""
        builder = self.builder
        data = None
        if buf is not None:
            data = self.write_aligned_bytes(buf)
        Buffer.BufferStart(builder)
        if data is not None:
            Buffer.BufferAddData(builder, data)
        return Buffer.BufferEnd(builder)

    def serialise_metadata(self, metadata):
        """Serialise a (name, buffer index) pair as a Metadata table."""
        builder = self.builder
        name = builder.CreateString(metadata[0])

        Metadata.MetadataStart(builder)
        Metadata.MetadataAddName(builder, name)
        Metadata.MetadataAddBuffer(builder, metadata[1])

        return Metadata.MetadataEnd(builder)

    def serialise_model(self):
        """Serialise operator codes, subgraphs, buffers and metadata into a Model table.

        If nng.metadata has no "OfflineMemoryAllocation" entry, one is appended to it (this mutates
        self.nng). It holds tensor arena offsets for the last serialised subgraph.

        :return: offset of the Model table.
        """
        builder = self.builder
        operator_code_offset = self.write_offset_vector(
            [self.serialise_operator_code(idx, optype, code) for idx, (optype, code) in enumerate(self.operator_codes)]
        )

        description = builder.CreateString("Vela Optimised")

        # NOTE: assign_buffers_to_tensors() resets buffers_to_write per subgraph, so only a single CPU
        # subgraph is handled correctly here
        subgraph_offset = self.write_offset_vector([self.serialise_subgraph(sg) for sg in self.subgraphs_to_write])

        # Fill the metadata buffer
        version = np.int32(0)
        subgraph_idx = np.int32(len(self.subgraphs_to_write))  # Only 1 supported currently
        nbr_tensors = np.int32(len(self.tensor_map))

        # Existing names are presumably bytes when read from an input model (confirm against the reader),
        # while the entry appended below uses a str name; accept both so an entry is never duplicated
        offline_alloc_names = ("OfflineMemoryAllocation", b"OfflineMemoryAllocation")
        if not any(name in offline_alloc_names for name, _ in self.nng.metadata):
            # An offset of -1 indicates that the tensor will be allocated online by Tensorflow Lite Micro
            offsets = [np.int32(-1)] * nbr_tensors

            # Ensure that the order of the offsets match the order of the tensors
            for tens, idx in self.tensor_map.items():
                # Set offsets for tensor allocated in Tensor Arena or in the scratch_fast area
                if tens.mem_type in (MemType.Scratch, MemType.Scratch_fast):
                    offsets[idx] = np.int32(tens.address) if tens.address is not None else np.int32(0)

            self.nng.metadata.append(
                ("OfflineMemoryAllocation", np.array([version, subgraph_idx, nbr_tensors] + offsets))
            )

        # Each metadata payload is stored as an extra buffer after the tensor buffers
        metadata_list = []
        for name, buffer in self.nng.metadata:
            self.buffers_to_write.append(buffer)
            metadata_list.append((name, len(self.buffers_to_write) - 1))

        buffers_offset = self.write_offset_vector([self.serialise_buffer(buf) for buf in self.buffers_to_write])
        metadata_offset = self.write_offset_vector([self.serialise_metadata(metadata) for metadata in metadata_list])

        Model.ModelStart(builder)
        Model.ModelAddVersion(builder, tflite_version)
        Model.ModelAddOperatorCodes(builder, operator_code_offset)
        Model.ModelAddSubgraphs(builder, subgraph_offset)
        Model.ModelAddDescription(builder, description)
        Model.ModelAddBuffers(builder, buffers_offset)
        Model.ModelAddMetadata(builder, metadata_offset)
        return Model.ModelEnd(builder)

    def serialise(self):
        """Build the complete model, finish it with the TFLite file identifier and return the bytes.

        The result is also cached in self.serialised_buf for write().
        """
        model = self.serialise_model()

        self.builder.FinishWithFileIdentifier(model, tflite_file_identifier)

        self.serialised_buf = self.builder.Output()
        return self.serialised_buf

    def write(self, filename):
        """Write the serialised model to filename, serialising it first if that has not been done yet.

        serialise() keeps appending to the same builder (and mutates nng.metadata), so the cached
        buffer is reused instead of serialising again.
        """
        if self.serialised_buf is None:
            self.serialise()
        with open(filename, "wb") as f:
            f.write(self.serialised_buf)
468
469
def write_tflite(nng, filename):
    """Serialise the network graph nng to TensorFlow Lite flatbuffer format and save it to filename."""
    serialised = TFLiteSerialiser(nng).serialise()

    with open(filename, "wb") as out_file:
        out_file.write(serialised)
475 f.write(buf)