# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Functions used to read from a TensorFlow Lite format file.
import os.path

import numpy as np

from .nn_graph import Graph
from .nn_graph import Subgraph
from .operation import Operation
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tflite.BuiltinOperator import BuiltinOperator
from .tflite.Model import Model
from .tflite_mapping import builtin_operator_map
from .tflite_mapping import DataType
from .tflite_mapping import datatype_map
from .tflite_mapping import datatype_map_numpy

def decode_str(s):
    if s is None:
        return ""
    return s.decode("utf-8")


def reshape_tensor_add_const_op(tens, reorder):
    if not tens.reshaped:
        original_shape = tens.shape
        tens.name = tens.name + "_reshape"
        tens.shape = [original_shape[idx] for idx in reorder]
        tens.bandwidth_shape = tens.shape
        tens.storage_shape = tens.shape

        if tens.values is not None:
            tens.values = tens.values.transpose(reorder)

        if tens.quant_values is not None:
            tens.quant_values = tens.quant_values.transpose(reorder)

        op = Operation("Const", tens.name)
        op.outputs = [tens]
        tens.ops = [op]
        tens.reshaped = True
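
# A minimal sketch of the reorder applied above: TFLite stores Conv2D weights
# as (out_channels, kernel_h, kernel_w, in_channels), so the (1, 2, 3, 0)
# reorder that parse_operator passes in for convolutions yields the
# (kernel_h, kernel_w, in_channels, out_channels) layout:
#
#     w = np.zeros((8, 3, 3, 4))        # OHWI, as stored in the flatbuffer
#     w.transpose((1, 2, 3, 0)).shape   # -> (3, 3, 4, 8), i.e. HWIO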


class TFLiteSubgraph:
    def __init__(self, graph, subgraph):
        self.graph = graph
        self.name = decode_str(subgraph.Name())

        self.tensors = []
        for idx in range(subgraph.TensorsLength()):
            self.tensors.append(self.parse_tensor(subgraph.Tensors(idx)))

        for idx in range(subgraph.OperatorsLength()):
            self.parse_operator(subgraph.Operators(idx))

        self.outputs = [self.tensors[idx] for idx in subgraph.OutputsAsNumpy()]
        self.inputs = [self.tensors[idx] for idx in subgraph.InputsAsNumpy()]

        # Fix up tensors without operations. Generate either Placeholder or Constant ops
        for tens in self.inputs:
            assert not tens.ops
            op = Operation("Placeholder", tens.name)
            op.outputs = [tens]
            tens.ops = [op]

        for tens in self.tensors:
            if not tens.ops:
                op = Operation("Const", tens.name)
                op.outputs = [tens]
                tens.ops = [op]

    def parse_tensor(self, tens_data):
        np_shape = tens_data.ShapeAsNumpy()
        shape = list(np_shape) if isinstance(np_shape, np.ndarray) else []
        name = decode_str(tens_data.Name())
        dtype = datatype_map[tens_data.Type()]

        tens = Tensor(shape, dtype, name)

        quant = tens_data.Quantization()

        def len1_array_to_scalar(arr):
            # The following flatbuffer quantisation fields all return a scalar value of 0 if they are not defined in
            # the input buffer. This is represented in Vela by using None.
            # Otherwise, the fields return a single- or multi-element array, in which case single-element arrays
            # are converted to scalars.
            if isinstance(arr, int) and arr == 0:
                return None
            if len(arr) == 1:
                return arr[0]
            return arr
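
        # For example (a sketch of the flatbuffer accessor behaviour assumed above):
        #     quant.ScaleAsNumpy() -> 0                   no scale present   -> None
        #     quant.ScaleAsNumpy() -> array([0.5])        per-tensor scale   -> 0.5
        #     quant.ScaleAsNumpy() -> array([0.5, 0.25])  per-channel scales -> unchanged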

        tens.quantization = QuantizationParameters()
        tens.quantization.min = len1_array_to_scalar(quant.MinAsNumpy())
        tens.quantization.max = len1_array_to_scalar(quant.MaxAsNumpy())
        tens.quantization.scale_f32 = len1_array_to_scalar(quant.ScaleAsNumpy())
        tens.quantization.zero_point = len1_array_to_scalar(quant.ZeroPointAsNumpy())

        if dtype == DataType.uint8:
            tens.quantization.quant_min = 0
            tens.quantization.quant_max = (1 << dtype.bits) - 1
        elif dtype in set((DataType.int8, DataType.int16, DataType.int32, DataType.int64)):
            tens.quantization.quant_min = -(1 << (dtype.bits - 1))
            tens.quantization.quant_max = (1 << (dtype.bits - 1)) - 1
        else:
            raise Exception("DataType '" + str(dtype) + "' is not supported for quantization.")
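
        # Worked examples of the ranges computed above:
        #     uint8: quant_min = 0,      quant_max = (1 << 8) - 1 = 255
        #     int8:  quant_min = -128,   quant_max = (1 << 7) - 1 = 127
        #     int16: quant_min = -32768, quant_max = 32767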

        if tens.quantization.scale_f32 is None and tens.quantization.zero_point is None:
            tens.quantization = None

        tens.values = None
        buf = self.graph.buffers[tens_data.Buffer()]
        if buf is not None:
            tens.values = np.array(buf.view(datatype_map_numpy[tens_data.Type()]).reshape(shape))
            if tens.quantization is not None:
                tens.quant_values = tens.values
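                # quant_values keeps the raw integer data; values is assumed to
                # hold the dequantized floats, i.e. roughly
                # scale_f32 * (quant_values - zero_point), with the exact
                # behaviour defined by QuantizationParameters.dequantize().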
                tens.values = tens.quantization.dequantize(tens.quant_values)
        return tens

    def parse_operator(self, op_data):
        op_type, opt_serializer = self.graph.operator_codes[op_data.OpcodeIndex()]
        inputs = [self.tensors[idx] for idx in op_data.InputsAsNumpy()]
        outputs = [self.tensors[idx] for idx in op_data.OutputsAsNumpy()]
        name = "unknown_op_name"
        if outputs:
            name = outputs[0].name
        op = Operation(op_type, name)
        op.inputs = inputs
        op.outputs = outputs
        for out in op.outputs:
            out.ops = [op]

        activation_function_to_split_out = None

        if op_type.startswith("DepthwiseConv2d") or op_type.startswith("Conv2D"):
            reshape_tensor_add_const_op(inputs[1], (1, 2, 3, 0))

        if op_type.startswith("FullyConnected"):
            reshape_tensor_add_const_op(inputs[1], (1, 0))

        if opt_serializer is not None:
            op.attrs = opt_serializer.deserialize(op_data.BuiltinOptions(), op_data.CustomOptionsAsNumpy())

            if op_type.startswith("ResizeBilinear"):
                upscaled_shape = [op.inputs[0].shape[1] * 2, op.inputs[0].shape[2] * 2]
                out_shape = op.outputs[0].shape[1:3]
                if not op.attrs["align_corners"] and out_shape == upscaled_shape:
                    # the output is a 2x upscale, so we need SAME padding
                    op.attrs.update({"padding": b"SAME"})
                elif op.attrs["align_corners"] and out_shape == [upscaled_shape[0] - 1, upscaled_shape[1] - 1]:
                    # here we can just run the avg pool without padding and
                    # produce a (M * 2 - 1, N * 2 - 1) sized output
                    op.attrs.update({"padding": b"VALID"})
                else:
                    assert False, "Only 2x upscaling is supported"
                op.attrs.update({"filter_width": 2, "filter_height": 2, "stride_w": 1, "stride_h": 1})
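
            # Shape sketch for the two supported ResizeBilinear cases, assuming
            # an NHWC input of (1, 8, 8, C):
            #     align_corners=False: output (1, 16, 16, C) -> SAME padding
            #     align_corners=True:  output (1, 15, 15, C) -> VALID padding
            # Anything else trips the assert above.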

            if "stride_w" in op.attrs:
                op.attrs["strides"] = (1, op.attrs["stride_h"], op.attrs["stride_w"], 1)
            if "filter_width" in op.attrs:
                op.attrs["ksize"] = (1, op.attrs["filter_height"], op.attrs["filter_width"], 1)
            if "dilation_w_factor" in op.attrs:
                op.attrs["dilation"] = (1, op.attrs["dilation_h_factor"], op.attrs["dilation_w_factor"], 1)
            if "depth_multiplier" in op.attrs:
                op.attrs["channel_multiplier"] = op.attrs["depth_multiplier"]
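
            # For example, an op with stride_h=2, stride_w=2 gets the
            # TensorFlow-style NHWC attribute op.attrs["strides"] == (1, 2, 2, 1),
            # and a pooling op with a 2x2 filter gets op.attrs["ksize"] == (1, 2, 2, 1).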

            if "fused_activation_function" in op.attrs:
                if op_type in set(("ConcatTFLite",)):
                    act = op.attrs["fused_activation_function"]
                    del op.attrs["fused_activation_function"]
                    if act is not None:
                        activation_function_to_split_out = act

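        # A fused activation split out above is rewired here as a standalone
        # op; schematically:
        #     before: op -> out_tens
        #     after:  op -> intermediate_tens -> act_op -> out_tens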
        if activation_function_to_split_out is not None:
            act_op = Operation(activation_function_to_split_out, name + activation_function_to_split_out)
            out_tens = op.outputs[0]
            intermediate_tens = out_tens.clone("_act_intermediate")
            out_tens.ops = [act_op]
            act_op.outputs = [out_tens]
            intermediate_tens.ops = [op]
            op.outputs[0] = intermediate_tens
            act_op.inputs = [intermediate_tens]


class TFLiteGraph:
    def __init__(
        self, filename, batch_size=1, feed_dict=None, output_node_names=None, initialisation_nodes=None,
    ):

        self.op_times = {}
        if batch_size is None:
            batch_size = 1
        self.batch_size = batch_size
        self.name = os.path.splitext(os.path.basename(filename))[0]
        self.initialisation_nodes = initialisation_nodes if initialisation_nodes is not None else []

        with open(filename, "rb") as f:
            buf = bytearray(f.read())

        model = Model.GetRootAsModel(buf, 0)

        self.buffers = []
        for idx in range(model.BuffersLength()):
            self.buffers.append(self.parse_buffer(model.Buffers(idx)))

        self.operator_codes = []
        for idx in range(model.OperatorCodesLength()):
            self.operator_codes.append(self.parse_operator_code(model.OperatorCodes(idx)))

        self.subgraphs = []
        for idx in range(model.SubgraphsLength()):
            self.subgraphs.append(TFLiteSubgraph(self, model.Subgraphs(idx)))

        self.nng = Graph(self.name, self.batch_size)
        for tflite_sg in self.subgraphs:
            sg = Subgraph(tflite_sg.name)
            sg.original_inputs = tflite_sg.inputs  # Preserve the original input order
            sg.output_tensors = tflite_sg.outputs
            self.nng.subgraphs.append(sg)

    def parse_buffer(self, buf_data):
        if buf_data.DataLength() == 0:
            return None
        data = buf_data.DataAsNumpy()
        return data

    def parse_operator_code(self, code):
        c = code.BuiltinCode()
        op_type, ser = builtin_operator_map[c]
        if c == BuiltinOperator.CUSTOM:
            op_type += decode_str(code.CustomCode())
        return op_type, ser
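
    # For custom operators, the registered custom name is appended to the base
    # type looked up in builtin_operator_map. As a sketch (the exact base
    # string for CUSTOM is defined in tflite_mapping, not here), a flatbuffer
    # entry with CustomCode "ethos-u" would map to <custom_base> + "ethos-u".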


def read_tflite(
    filename, batch_size=1, feed_dict=None, output_node_names=None, initialisation_nodes=None,
):
    tflite_graph = TFLiteGraph(filename, batch_size, feed_dict, output_node_names, initialisation_nodes)
    nng = tflite_graph.nng
    nng.refresh_after_modification()
    return nng
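

# A minimal usage sketch, assuming this module is imported from the Vela
# package (the "ethosu.vela" import path and "model.tflite" filename are
# illustrative, not part of this file):
#
#     from ethosu.vela.tflite_reader import read_tflite
#
#     nng = read_tflite("model.tflite", batch_size=1)
#     for sg in nng.subgraphs:
#         print(sg.name, [t.name for t in sg.output_tensors])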