blob: 56e0b1cb37938b0cb6cb744a301924cd261224e1 [file] [log] [blame]
# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Functions used to read from a TOSA format file.
19import os.path
20import struct
21import sys
22
23import numpy as np
24
25from .nn_graph import Graph
26from .nn_graph import Subgraph
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +020027from .operation import ExplicitScaling
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020028from .operation import Op
29from .operation import Operation
Patrik Gustavsson5e26eda2021-06-30 09:07:16 +020030from .reader_util import align_tensor_indices_to_nng
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020031from .reader_util import clone_and_reshape_tensor
32from .reader_util import decode_str
33from .reader_util import fixup_tensors
Patrik Gustavssonf1580f02021-09-01 12:43:02 +020034from .shape4d import Shape4D
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020035from .tensor import QuantizationParameters
Patrik Gustavssonc74682c2021-08-17 14:26:38 +020036from .tensor import shape_num_elements
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020037from .tensor import Tensor
38from .tflite_mapping import DataType
Patrik Gustavssondf995102021-08-23 15:33:59 +020039from .tosa.Op import Op as TosaOp
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020040from .tosa.TosaGraph import TosaGraph as TG
41from .tosa_mapping import datatype_map
Patrik Gustavssond15866c2021-08-10 13:56:34 +020042from .tosa_mapping import datatype_map_numpy
Patrik Gustavssondf995102021-08-23 15:33:59 +020043from .tosa_mapping import TOSA_IFM_INDICES
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020044from .tosa_mapping import tosa_operator_map
45from .tosa_mapping import unsupported_tosa_operators
46
47
class TosaSubgraph:
    """Tensors, operations, inputs and outputs parsed from one block of a TOSA graph.

    Construction parses every tensor of the block first, then every operator
    (operators refer to tensors by name), and finally resolves the block's
    input and output names to the parsed tensors.
    """

    def __init__(self, graph, block):
        """Parse ``block`` on behalf of ``graph``.

        Args:
            graph: the owning object; stored as ``self.graph``.
            block: serialized block exposing ``Name()``, ``Tensors(i)``/``TensorsLength()``,
                ``Operators(i)``/``OperatorsLength()``, ``Inputs(i)``/``InputsLength()`` and
                ``Outputs(i)``/``OutputsLength()``.
        """
        self.graph = graph
        self.name = decode_str(block.Name())

        self.tensors = []
        for idx in range(block.TensorsLength()):
            self.tensors.append(self.parse_tensor(block.Tensors(idx)))

        for idx in range(block.OperatorsLength()):
            self.parse_operator(idx, block.Operators(idx))

        # Get the subgraph inputs and outputs
        self.inputs = self.get_sg_inputs_remove_duplicates(block)
        self.outputs = self.get_sg_outputs_remove_duplicates(block)
        fixup_tensors(self.inputs, self.tensors)

    def get_sg_inputs_remove_duplicates(self, block):
        """Return the block's input tensors in declaration order, each at most once."""
        inputs = []
        for idx in range(block.InputsLength()):
            self.add_not_duplicate(block.Inputs(idx), inputs, "input")
        return inputs

    def get_sg_outputs_remove_duplicates(self, block):
        """Return the block's output tensors in declaration order, each at most once."""
        outputs = []
        for idx in range(block.OutputsLength()):
            self.add_not_duplicate(block.Outputs(idx), outputs, "output")
        return outputs

    def add_not_duplicate(self, tens_data, tensors, warning_str):
        """Append the tensor named by ``tens_data`` to ``tensors`` unless it is already there.

        Args:
            tens_data: encoded tensor name, decoded with ``decode_str``.
            tensors: list that is extended in place.
            warning_str: word ("input"/"output") used in the duplicate warning.
        """
        name = decode_str(tens_data)
        tensor = self.get_tensor_by_name(name)
        if tensor not in tensors:
            tensors.append(tensor)
        else:
            print(f"Warning: Subgraph {warning_str} tensor ({tensor}) already seen. Removing the duplicate.")

    def get_tensor_by_name(self, name):
        """Return the first parsed tensor whose ``name`` equals ``name``, or None if there is none."""
        for tens in self.tensors:
            if tens.name == name:
                return tens
        return None

    def _require_tensor(self, name, role):
        """Return the parsed tensor called ``name``; raise AssertionError if it does not exist."""
        tens = self.get_tensor_by_name(name)
        if tens is None:
            # Explicit raise instead of an assert statement so the check survives "python -O"
            raise AssertionError(f"Operator {role} tensor '{name}' does not match any parsed tensor")
        return tens

    def parse_operator(self, op_index, op_data):
        """Create an ``Operation`` for one serialized operator and connect its tensors.

        Operators listed in ``unsupported_tosa_operators`` are reported and skipped.

        Args:
            op_index: position of the operator in its block; stored as ``op.op_index``.
            op_data: serialized operator exposing ``Op()``, ``Inputs(i)``/``InputsLength()``
                and ``Outputs(i)``/``OutputsLength()``.

        Raises:
            AssertionError: if an input or output name matches no parsed tensor, or if a
                "stride" or "kernel" attribute does not have exactly two entries.
        """
        op_code = op_data.Op()
        if op_code in unsupported_tosa_operators:
            print("Unsupported Operator", op_code)
            return

        op_type, attr_serializer, quant_serializer, indices = tosa_operator_map[op_code]
        inputs = [
            self._require_tensor(decode_str(op_data.Inputs(idx)), "input") for idx in range(op_data.InputsLength())
        ]
        outputs = [
            self._require_tensor(decode_str(op_data.Outputs(idx)), "output") for idx in range(op_data.OutputsLength())
        ]

        # Permutation attribute for TRANSPOSE is an input tensor in TOSA
        # TODO In order to optimise Depthwise spawning from TFLite Support for removing
        # Transpose of constant data.
        # Moving permutation to an attribute, to match internal graph representation for now
        perms = None
        if op_code == TosaOp.TRANSPOSE:
            perms = inputs.pop(1)
            indices = TOSA_IFM_INDICES

        # The operation is named after its first output when it has one
        name = outputs[0].name if outputs else "unknown_op_name"
        inputs = align_tensor_indices_to_nng(op_type, indices, inputs)
        op = Operation(op_type, name)
        op.op_index = op_index
        op.inputs = inputs
        op.outputs = outputs

        for out in op.outputs:
            out.ops = [op]

        # TODO Transpose_conv and conv3d
        if op.type.is_depthwise_conv2d_op() or op.type.is_conv2d_op() or op.type == Op.FullyConnected:
            if inputs[1].values is not None:
                # Constant weights are cloned with their axes permuted; the permutation depends on the op kind
                if op.type == Op.FullyConnected:
                    inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0), False)
                elif op.type.is_conv2d_op():
                    inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0), False)
                elif op.type.is_depthwise_conv2d_op():
                    inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 0, 3), False)
            if op.type.needs_bias() and len(inputs) <= op_type.info.indices.biases[0]:
                # No Bias tensor
                inputs.append(None)
            if inputs[-1] and inputs[-1].values is not None:
                # Since bias tensor is used for both bias and scale,
                # a clone with a unique equivalence_id is needed
                inputs[-1] = clone_and_reshape_tensor(inputs[-1], (0,), True)

        if attr_serializer is not None:
            op.attrs = attr_serializer.deserialize(op_data)

            if "padding" in op.attrs:
                padding = op.attrs["padding"]  # [top, bottom, left, right]
                op.attrs["explicit_padding"] = (
                    padding[0],
                    padding[2],
                    padding[1],
                    padding[3],
                )  # [top, left, bottom, right]
            if "stride" in op.attrs:
                stride = op.attrs["stride"]
                if len(stride) == 2:
                    op.attrs["strides"] = (1, stride[0], stride[1], 1)
                    del op.attrs["stride"]
                else:
                    # TODO CONV3D more to be done....
                    print("Unsupported kernel dimensions: ", len(stride))
                    raise AssertionError(f"Unsupported stride dimensions: {len(stride)}")
            if "dilation" in op.attrs:
                dilation = op.attrs["dilation"]
                if len(dilation) == 2:
                    op.attrs["dilation"] = (1, dilation[0], dilation[1], 1)
                elif len(dilation) == 3:
                    # TODO CONV3D more to be done....
                    op.attrs["dilation"] = (dilation[0], dilation[1], dilation[2], 1)
            if "kernel" in op.attrs:
                kernel = op.attrs["kernel"]
                if len(kernel) == 2:
                    op.attrs["ksize"] = (1, kernel[0], kernel[1], 1)
                else:
                    # TODO CONV3D more to be done....
                    print("Unsupported kernel dimensions: ", len(kernel))
                    raise AssertionError(f"Unsupported kernel dimensions: {len(kernel)}")
            if "shift" in op.attrs and op.type == Op.Mul:
                shift = op.attrs["shift"]
                if shift != 0:
                    op.explicit_scaling = ExplicitScaling(False, [shift], [1])
            if op.type.is_depthwise_conv2d_op():
                op.attrs["depth_multiplier"] = op.weights.shape[3]
            if op.type == Op.SplitSliceRead:
                op.read_offsets[0] = Shape4D.from_list(list(op.attrs["begin"]), 0)
                op.read_shapes[0] = op.attrs["size"]

        elif op.type == Op.Transpose:
            # The permutation was removed from the inputs above and is kept as an attribute instead
            op.attrs["perms"] = perms.values

        if quant_serializer is not None:
            quant_info = quant_serializer.deserialize(op_data)

            # TODO tensor zero points currently set here
            # zero points part of Rescale operation, handled in tosa_graph_optimizer
            if "input_zp" in quant_info:
                self.set_tensor_zp(op.ifm, quant_info["input_zp"])
            if "weight_zp" in quant_info:
                self.set_tensor_zp(op.weights, quant_info["weight_zp"])
            if "output_zp" in quant_info:
                self.set_tensor_zp(op.ofm, quant_info["output_zp"])
            if "a_zp" in quant_info:
                self.set_tensor_zp(op.ifm, quant_info["a_zp"])
            if "b_zp" in quant_info:
                self.set_tensor_zp(op.ifm2, quant_info["b_zp"])

    def parse_tensor(self, tens_data):
        """Build a ``Tensor`` from one serialized tensor.

        The tensor gets fresh ``QuantizationParameters`` whose min/max are set from the
        bit width for uint8 and the signed integer types. ``values`` is None unless the
        serialized tensor carries data.

        Raises:
            KeyError: if the serialized type is not in ``datatype_map``.
            AssertionError: if the tensor carries data but its type has no entry in
                ``datatype_map_numpy``.
        """
        name = decode_str(tens_data.Name())
        np_shape = tens_data.ShapeAsNumpy()
        # Anything other than an ndarray is treated as "no shape"
        shape = list(np_shape) if isinstance(np_shape, np.ndarray) else []
        tens_dtype = tens_data.Type()
        dtype = datatype_map[tens_dtype]

        tens = Tensor(shape, dtype, name)

        # Initialize quantization parameters
        tens.quantization = QuantizationParameters()

        if dtype == DataType.uint8:
            tens.quantization.quant_min = 0
            tens.quantization.quant_max = (1 << dtype.bits) - 1
        elif dtype in (DataType.int8, DataType.int16, DataType.int32, DataType.int48):
            tens.quantization.quant_min = -(1 << (dtype.bits - 1))
            tens.quantization.quant_max = (1 << (dtype.bits - 1)) - 1

        tens.values = None

        data_length = tens_data.DataLength()
        if data_length != 0:
            data_as_numpy = tens_data.DataAsNumpy()
            if tens_dtype in datatype_map_numpy:
                np_dtype = datatype_map_numpy[tens_dtype]

                # TOSA pads the tensor data
                shape_elements = shape_num_elements(shape)
                values = np.array(data_as_numpy.view(np_dtype))
                values = values[0:shape_elements]
                tens.values = values.reshape(shape)
            else:
                # int48 is only expected as an accumulated data/output format, int4 not supported
                message = f"Error: unsupported/unexpected Tensor type {dtype}, with data"
                print(message)
                raise AssertionError(message)

        return tens

    def set_tensor_zp(self, tens, zp):
        """Set the zero point of ``tens`` to ``zp`` if it has none yet.

        Setting the same value again is a no-op.

        Raises:
            AssertionError: if the tensor already has a different zero point.
        """
        if tens.quantization.zero_point is None:
            tens.quantization.zero_point = zp
        elif tens.quantization.zero_point != zp:
            message = "Error: Setting tensor zp not possible, tensor already has different zero point"
            print(message)
            # Explicit raise instead of "assert False" so the conflict is still caught under "python -O"
            raise AssertionError(message)
259
260
class TosaGraph:
    """Reads a .tosa file and builds the network graph, available as ``self.nng``."""

    # The only serialized TOSA version accepted by check_version
    SUPPORTED_VERSION = "0.22.0"

    def __init__(self, filename, batch_size, feed_dict, output_node_names, initialisation_nodes):
        """Load ``filename`` and parse every block into a subgraph.

        Args:
            filename: path of the file to read; its base name without extension becomes ``self.name``.
            batch_size: stored as ``self.batch_size``; None is replaced by 1.
            feed_dict: accepted but not used by this constructor.
            output_node_names: accepted but not used by this constructor.
            initialisation_nodes: stored as ``self.initialisation_nodes``.

        A ``struct.error``, ``TypeError`` or ``RuntimeError`` raised while parsing is reported
        together with the parsing step that failed, and the process exits with status 1.

        Raises:
            AssertionError: if the file declares an unsupported TOSA version.
        """
        self.op_times = {}
        if batch_size is None:
            batch_size = 1
        self.batch_size = batch_size
        self.name = os.path.splitext(os.path.basename(filename))[0]
        self.initialisation_nodes = initialisation_nodes

        with open(filename, "rb") as f:
            buf = bytearray(f.read())

        try:
            # parsing_step tracks progress so that the error message can say what failed
            parsing_step = "parsing root"
            tosa_graph = TG.GetRootAsTosaGraph(buf, 0)

            parsing_step = "parsing version"
            self.check_version(tosa_graph)

            parsing_step = "parsing blocks length"
            self.subgraphs = []
            for b_idx in range(tosa_graph.BlocksLength()):
                parsing_step = f"parsing block {b_idx}"
                self.subgraphs.append(TosaSubgraph(self, tosa_graph.Blocks(b_idx)))

            self.nng = Graph(self.name, self.batch_size)
            for tosa_sg in self.subgraphs:
                sg = Subgraph(tosa_sg.name)
                sg.original_inputs = tosa_sg.inputs  # Preserve the original input order
                sg.output_tensors = tosa_sg.outputs
                self.nng.subgraphs.append(sg)

        except (struct.error, TypeError, RuntimeError) as e:
            print(f'Error: Invalid .tosa file. Got "{e}" while {parsing_step}.')
            sys.exit(1)

    def check_version(self, tosa_graph):
        """Verify that ``tosa_graph`` declares the supported TOSA version.

        Raises:
            AssertionError: if the "major.minor.patch" version differs from ``SUPPORTED_VERSION``.
        """
        version = tosa_graph.Version()
        version_str = f"{version._major()}.{version._minor()}.{version._patch()}"
        if version_str != self.SUPPORTED_VERSION:
            message = f"Unsupported TOSA version: {version_str}"
            print(message)
            # Explicit raise instead of "assert False" so the check is not removed by "python -O"
            raise AssertionError(message)
304
305
def read_tosa(filename, batch_size, feed_dict, output_node_names, initialisation_nodes):
    """Read the TOSA file ``filename`` and return the resulting network graph.

    All arguments are forwarded unchanged to ``TosaGraph``. The graph's
    ``refresh_after_modification`` is called before it is returned.
    """
    reader = TosaGraph(filename, batch_size, feed_dict, output_node_names, initialisation_nodes)
    graph = reader.nng
    graph.refresh_after_modification()
    return graph
310 return nng