# SPDX-FileCopyrightText: Copyright 2021-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Functions used to read from a TOSA format file.
import os.path
import struct
import sys

import numpy as np

from .nn_graph import Graph
from .nn_graph import Subgraph
from .operation import ExplicitScaling
from .operation import Op
from .operation import Operation
from .reader_util import align_tensor_indices_to_nng
from .reader_util import clone_and_reshape_tensor
from .reader_util import decode_str
from .reader_util import fixup_tensors
from .shape4d import Shape4D
from .tensor import QuantizationParameters
from .tensor import shape_num_elements
from .tensor import Tensor
from .tflite_mapping import DataType
from .tosa.Op import Op as TosaOp
from .tosa.TosaGraph import TosaGraph as TG
from .tosa_mapping import datatype_map
from .tosa_mapping import datatype_map_numpy
from .tosa_mapping import tosa_operator_map
from .tosa_mapping import unsupported_tosa_operators


class TosaSubgraph:
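    """Parses a single TOSA block into the tensors and operations of one subgraph."""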
    def __init__(self, graph, block):
        self.graph = graph
        self.name = decode_str(block.Name())

        self.tensors = []
        for idx in range(block.TensorsLength()):
            self.tensors.append(self.parse_tensor(block.Tensors(idx)))

        for idx in range(block.OperatorsLength()):
            self.parse_operator(idx, block.Operators(idx))

        # Get the subgraph inputs and outputs
        self.inputs = self.get_sg_inputs_remove_duplicates(block)
        self.outputs = self.get_sg_outputs_remove_duplicates(block)
        fixup_tensors(self.inputs, self.tensors)

    def get_sg_inputs_remove_duplicates(self, block):
        inputs = []
        for idx in range(block.InputsLength()):
            tens_data = block.Inputs(idx)
            self.add_not_duplicate(tens_data, inputs, "input")
        return inputs

    def get_sg_outputs_remove_duplicates(self, block):
        outputs = []
        for idx in range(block.OutputsLength()):
            tens_data = block.Outputs(idx)
            self.add_not_duplicate(tens_data, outputs, "output")
        return outputs

    def add_not_duplicate(self, tens_data, tensors, warning_str):
        name = decode_str(tens_data)
        tensor = self.get_tensor_by_name(name)
        if tensor not in tensors:
            tensors.append(tensor)
        else:
            print(f"Warning: Subgraph {warning_str} tensor ({tensor}) already seen. Removing the duplicate.")

    def get_tensor_by_name(self, name):
        for tens in self.tensors:
            if tens.name == name:
                return tens
        return None

    def parse_operator(self, op_index, op_data):
        op_code = op_data.Op()
        if op_code in unsupported_tosa_operators:
            print("Unsupported Operator", op_code)
            for opname in dir(TosaOp):
                if op_code == getattr(TosaOp, opname):
                    print(f" {opname}")
            return

        op_type, attr_serializer, quant_serializer, indices = tosa_operator_map[op_code]
        inputs = []
        outputs = []
        for idx in range(op_data.InputsLength()):
            input = decode_str(op_data.Inputs(idx))
            input_tens = self.get_tensor_by_name(input)
            inputs.append(input_tens)
            if input_tens is None:
                print(f"could not find named input tensor {input}::{input_tens}")
            assert input_tens is not None

        for idx in range(op_data.OutputsLength()):
            output = decode_str(op_data.Outputs(idx))
            output_tens = self.get_tensor_by_name(output)
            outputs.append(output_tens)
            if output_tens is None:
                print(f"could not find named output tensor {output}::{output_tens}")
            assert output_tens is not None

        name = "unknown_op_name"
        if len(outputs):
            name = outputs[0].name
        inputs = align_tensor_indices_to_nng(op_type, indices, inputs)
        op = Operation(op_type, name)
        op.op_index = op_index
        op.inputs = inputs
        op.outputs = outputs

        for out in op.outputs:
            out.ops = [op]

        # TODO Transpose_conv and conv3d
        if op.type.is_depthwise_conv2d_op() or op.type.is_conv2d_op() or op.type == Op.FullyConnected:

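            # The compiler expects the weight input of conv-like ops to be a constant tensor,
            # so fold away any Identity/Transpose/Reshape producers until inputs[1] is a
            # constant with its values available.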
            def _remove_producing_identity_op(prod_op):
                # find the producing op that is not an identity op and return it
                while prod_op.type == Op.Identity:
                    prod_op = prod_op.inputs[0].ops[0]  # get previous op
                return prod_op

            def _check_and_get_connection(prod_op, tens):
                # check weight producing op can be connected to the weight tensor
                assert len(prod_op.outputs) == 1
                assert tens.shape == prod_op.outputs[0].shape
                # only need to connect the current op connection as the tensor consuming connections haven't been
                # initialised yet
                return prod_op.outputs[0]

            # remove identity ops directly connected to the weight input of conv like ops
            weights_producer_op = _remove_producing_identity_op(inputs[1].ops[0])
            inputs[1] = _check_and_get_connection(weights_producer_op, inputs[1])  # update connection

            if weights_producer_op.type == Op.Transpose:
                # remove the transpose op so that the weight producer becomes a const op
                transpose_op = weights_producer_op
                # remove identity ops directly connected to the input of the transpose op
                transpose_producer_op = _remove_producing_identity_op(transpose_op.inputs[0].ops[0])
                transpose_op.inputs[0] = _check_and_get_connection(
                    transpose_producer_op, transpose_op.inputs[0]
                )  # update connection

                perms = transpose_op.attrs["perms"]
                inputs[1] = clone_and_reshape_tensor(transpose_op.inputs[0], perms, False)

            if weights_producer_op.type == Op.Reshape:
                # remove the reshape op so that the weight producer becomes a const op
                reshape_op = weights_producer_op
                # remove identity ops directly connected to the input of the reshape op
                reshape_producer_op = _remove_producing_identity_op(reshape_op.inputs[0].ops[0])
                reshape_op.inputs[0] = _check_and_get_connection(
                    reshape_producer_op, reshape_op.inputs[0]
                )  # update connection

                tens = reshape_op.inputs[0].clone("_reshape", False)
                tens.values = np.reshape(tens.values, reshape_op.ofm.shape)
                tens.shape = reshape_op.ofm.shape
                tens._original_shape = tens.shape
                tens.bandwidth_shape = tens.shape
                tens.storage_shape = tens.shape

                tmp_op = Operation(Op.Const, tens.name)
                tmp_op.set_output_tensor(tens)
                inputs[1] = tens

            assert inputs[1].values is not None

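            # Re-order the weight tensor into the layout expected internally; the permutations
            # below assume the TOSA serialized weight layouts (e.g. HWCM for depthwise weights).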
            if op.type == Op.FullyConnected:
                inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0), False)
            elif op.type.is_conv2d_op():
                inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0), False)
            elif op.type.is_depthwise_conv2d_op():
                HWCM_to_HWOI = (0, 1, 3, 2)
                inputs[1] = clone_and_reshape_tensor(inputs[1], HWCM_to_HWOI, False)
            if op.type.needs_bias() and len(inputs) <= op_type.info.indices.biases[0]:
                # No Bias tensor
                inputs.append(None)
            if inputs[-1] and inputs[-1].values is not None:
                # Since the bias tensor is used for both bias and scale,
                # a clone with a unique equivalence_id is needed
                inputs[-1] = clone_and_reshape_tensor(inputs[-1], (0,), True)

            op.explicit_scaling = ExplicitScaling(False, [0], [1])  # no scaling

        if attr_serializer is not None:
            op.attrs = attr_serializer.deserialize(op_data)

            if "pad" in op.attrs:
                op.attrs["padding"] = op.attrs["pad"]  # attribute was renamed to padding
                padding = op.attrs["padding"]  # [top, bottom, left, right]
                op.attrs["explicit_padding"] = (
                    padding[0],
                    padding[2],
                    padding[1],
                    padding[3],
                )  # [top, left, bottom, right]
            if "stride" in op.attrs:
                stride = op.attrs["stride"]
                if len(stride) == 2:
                    op.attrs["strides"] = (1, stride[0], stride[1], 1)
                    del op.attrs["stride"]
                else:
                    # TODO CONV3D more to be done....
                    print("Unsupported stride dimensions: ", len(stride))
                    assert False
            if "dilation" in op.attrs:
                dilation = op.attrs["dilation"]
                if len(dilation) == 2:
                    op.attrs["dilation"] = (1, dilation[0], dilation[1], 1)
                elif len(dilation) == 3:
                    # TODO CONV3D more to be done....
                    op.attrs["dilation"] = (dilation[0], dilation[1], dilation[2], 1)
            if "kernel" in op.attrs:
                kernel = op.attrs["kernel"]
                if len(kernel) == 2:
                    op.attrs["ksize"] = (1, kernel[0], kernel[1], 1)
                else:
                    # TODO CONV3D more to be done....
                    print("Unsupported kernel dimensions: ", len(kernel))
                    assert False
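            # A non-zero shift attribute on a TOSA MUL is recorded as explicit scaling
            # with multiplier 1 and the given shift.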
            if "shift" in op.attrs and op.type == Op.Mul:
                shift = op.attrs["shift"]
                if shift != 0:
                    op.explicit_scaling = ExplicitScaling(False, [shift], [1])
            if op.type.is_depthwise_conv2d_op():
                assert op.weights.shape[-1] % op.ifm.shape[-1] == 0
                depth_multiplier = op.weights.shape[-1] // op.ifm.shape[-1]  # exact division, asserted above
                if depth_multiplier > 1:
                    assert op.ifm.shape[-1] == 1 and op.ofm.shape[-1] == depth_multiplier, (
                        "For depth multipliers > 1, IFM channels must be 1 and "
                        "OFM channels must be equal to the depth multiplier")
                    op.attrs["depth_multiplier"] = depth_multiplier
            if op.type == Op.SplitSliceRead:
                op.read_offsets[0] = Shape4D.from_list(list(op.attrs["start"]), 0)
                op.read_shapes[0] = op.attrs["size"]

        # TODO tensor zero points currently set here
        # zero points part of Rescale operation, handled in tosa_graph_optimizer
        if "input_zp" in op.attrs:
            self.set_tensor_zp(op.ifm, op.attrs["input_zp"])
        if "weight_zp" in op.attrs:
            self.set_tensor_zp(op.weights, op.attrs["weight_zp"])
        if "output_zp" in op.attrs:
            self.set_tensor_zp(op.ofm, op.attrs["output_zp"])
        if "a_zp" in op.attrs:
            self.set_tensor_zp(op.ifm, op.attrs["a_zp"])
        if "b_zp" in op.attrs:
            self.set_tensor_zp(op.ifm2, op.attrs["b_zp"])

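    # Build a Tensor from the serialized TOSA tensor: shape, dtype, a default quantization
    # range derived from the dtype, and any constant values stored inline in the flatbuffer.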
    def parse_tensor(self, tens_data):
        name = decode_str(tens_data.Name())
        np_shape = tens_data.ShapeAsNumpy()
        shape = list(np_shape) if type(np_shape) is np.ndarray else []
        tens_dtype = tens_data.Type()
        dtype = datatype_map[tens_dtype]

        tens = Tensor(shape, dtype, name)

        # Initialize quantization parameters
        tens.quantization = QuantizationParameters()

        if dtype == DataType.uint8:
            tens.quantization.quant_min = 0
            tens.quantization.quant_max = (1 << dtype.bits) - 1
        elif dtype in (DataType.int8, DataType.int16, DataType.int32, DataType.int48):
            tens.quantization.quant_min = -(1 << (dtype.bits - 1))
            tens.quantization.quant_max = (1 << (dtype.bits - 1)) - 1

        tens.values = None

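        # Constant tensors carry their data inline; reinterpret the raw bytes using the
        # numpy dtype mapped from the TOSA type.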
        data_length = tens_data.DataLength()
        if data_length != 0:
            data_as_numpy = tens_data.DataAsNumpy()
            if tens_dtype in datatype_map_numpy:
                np_dtype = datatype_map_numpy[tens_dtype]

                # TOSA pads the tensor data
                shape_elements = shape_num_elements(shape)
                values = np.array(data_as_numpy.view(np_dtype))
                values = values[0:shape_elements]
                tens.values = values.reshape(shape)
            else:
                # int48 is only expected as an accumulated data/output format, int4 not supported
                print(f"Error: unsupported/unexpected Tensor type {dtype}, with data")
                assert False

        return tens

    def set_tensor_zp(self, tens, zp):
        if tens.quantization.zero_point is None:
            tens.quantization.zero_point = zp
        elif tens.quantization.zero_point != zp:
            print("Error: Setting tensor zp not possible, tensor already has a different zero point")
            assert False


class TosaGraph:
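    """Reads a serialized .tosa flatbuffer file and builds the internal network graph (nng)."""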
    def __init__(self, filename, batch_size, feed_dict, output_node_names, initialisation_nodes):
        self.op_times = {}
        if batch_size is None:
            batch_size = 1
        self.batch_size = batch_size
        self.name = os.path.splitext(os.path.basename(filename))[0]
        self.initialisation_nodes = initialisation_nodes

        with open(filename, "rb") as f:
            buf = bytearray(f.read())

        try:
            parsing_step = "parsing root"
            tosa_graph = TG.GetRootAsTosaGraph(buf, 0)

            parsing_step = "parsing version"
            self.check_version(tosa_graph)

            parsing_step = "parsing single main region"
            assert 1 == tosa_graph.RegionsLength()
            assert b"main" == tosa_graph.Regions(0).Name()

            parsing_step = "parsing blocks length"
            self.subgraphs = []
            for b_idx in range(tosa_graph.Regions(0).BlocksLength()):
                parsing_step = f"parsing block {b_idx}"
                self.subgraphs.append(TosaSubgraph(self, tosa_graph.Regions(0).Blocks(b_idx)))

            self.nng = Graph(self.name, self.batch_size)
            for tosa_sg in self.subgraphs:
                sg = Subgraph(tosa_sg.name)
                sg.original_inputs = tosa_sg.inputs  # Preserve the original input order
                sg.output_tensors = tosa_sg.outputs
                self.nng.subgraphs.append(sg)

        except (struct.error, TypeError, RuntimeError) as e:
            print(f'Error: Invalid .tosa file. Got "{e}" while {parsing_step}.')
            sys.exit(1)

    def check_version(self, tosa_graph):
        version = tosa_graph.Version()
        version_str = f"{version._Major()}.{version._Minor()}.{version._Patch()}"
        if version_str not in ("0.80.0", "0.80.1"):
            print(f"Unsupported TOSA version: {version_str}")
            assert False


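# Parse a .tosa file and return the internal network graph (nng) representation.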
def read_tosa(filename, batch_size, feed_dict, output_node_names, initialisation_nodes):
    tosa_graph = TosaGraph(filename, batch_size, feed_dict, output_node_names, initialisation_nodes)
    nng = tosa_graph.nng
    nng.refresh_after_modification()
    return nng