blob: eb3171696280a20cbe893e676306cc2cb26dabfc [file] [log] [blame]
# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Functions used to read from a TOSA format file.
18import os.path
19import struct
20import sys
21
22import numpy as np
23
24from .nn_graph import Graph
25from .nn_graph import Subgraph
26from .operation import Op
27from .operation import Operation
Patrik Gustavsson5e26eda2021-06-30 09:07:16 +020028from .reader_util import align_tensor_indices_to_nng
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020029from .reader_util import clone_and_reshape_tensor
30from .reader_util import decode_str
31from .reader_util import fixup_tensors
32from .tensor import QuantizationParameters
Patrik Gustavssonc74682c2021-08-17 14:26:38 +020033from .tensor import shape_num_elements
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020034from .tensor import Tensor
35from .tflite_mapping import DataType
36from .tosa.TosaGraph import TosaGraph as TG
37from .tosa_mapping import datatype_map
Patrik Gustavssond15866c2021-08-10 13:56:34 +020038from .tosa_mapping import datatype_map_numpy
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020039from .tosa_mapping import tosa_operator_map
40from .tosa_mapping import unsupported_tosa_operators
41
42
class TosaSubgraph:
    """In-memory representation of one block of a TOSA flatbuffer.

    Construction parses every tensor and operator of the block and then
    resolves the block's input and output tensor names to the parsed tensors.

    Malformed or unsupported content is reported by raising AssertionError.
    The exception is raised explicitly (not through an ``assert`` statement)
    so that the checks remain active when Python runs with ``-O``.
    """

    def __init__(self, graph, block):
        """Parse ``block`` (a flatbuffer block object) belonging to ``graph``."""
        self.graph = graph
        self.name = decode_str(block.Name())

        self.tensors = [self.parse_tensor(block.Tensors(idx)) for idx in range(block.TensorsLength())]

        # Operators look tensors up by name, so all tensors must be parsed first
        for idx in range(block.OperatorsLength()):
            self.parse_operator(idx, block.Operators(idx))

        # Get the subgraph inputs and outputs
        self.inputs = self.get_sg_inputs_remove_duplicates(block)
        self.outputs = self.get_sg_outputs_remove_duplicates(block)
        fixup_tensors(self.inputs, self.tensors)

    def get_sg_inputs_remove_duplicates(self, block):
        """Return the block's input tensors in file order, without duplicates."""
        inputs = []
        for idx in range(block.InputsLength()):
            tens_data = block.Inputs(idx)
            self.add_not_duplicate(tens_data, inputs, "input")
        return inputs

    def get_sg_outputs_remove_duplicates(self, block):
        """Return the block's output tensors in file order, without duplicates."""
        outputs = []
        for idx in range(block.OutputsLength()):
            tens_data = block.Outputs(idx)
            self.add_not_duplicate(tens_data, outputs, "output")
        return outputs

    def add_not_duplicate(self, tens_data, tensors, warning_str):
        """Append the tensor named by ``tens_data`` to ``tensors`` unless already present.

        ``tens_data`` is the encoded tensor name. ``warning_str`` ("input" or
        "output" at the call sites in this class) is only used in the warning
        printed when a duplicate is skipped.
        """
        name = decode_str(tens_data)
        tensor = self.get_tensor_by_name(name)
        if tensor not in tensors:
            tensors.append(tensor)
        else:
            print(f"Warning: Subgraph {warning_str} tensor ({tensor}) already seen. Removing the duplicate.")

    def get_tensor_by_name(self, name):
        """Return the first tensor in ``self.tensors`` called ``name``, or None."""
        for tens in self.tensors:
            if tens.name == name:
                return tens
        return None

    def parse_operator(self, op_index, op_data):
        """Build an Operation from the flatbuffer operator ``op_data``.

        The operation's input and output tensors are looked up by name among
        the tensors already parsed for this subgraph, and the operation is
        recorded as the producer (``ops``) of each of its outputs.

        Raises:
            AssertionError: if the operator is unsupported, if it references a
                tensor name that is not part of this subgraph, or if its
                stride/kernel attribute does not have exactly two entries.
        """
        op_code = op_data.Op()
        if op_code in unsupported_tosa_operators:
            print("Unsupported Operator", op_code)
            raise AssertionError(f"Unsupported Operator {op_code}")

        op_type, attr_serializer, quant_serializer, indices = tosa_operator_map[op_code]
        inputs = []
        outputs = []
        for idx in range(op_data.InputsLength()):
            tens_name = decode_str(op_data.Inputs(idx))
            input_tens = self.get_tensor_by_name(tens_name)
            if input_tens is None:
                raise AssertionError(f"Operator {op_index}: input tensor '{tens_name}' not found in subgraph")
            inputs.append(input_tens)

        for idx in range(op_data.OutputsLength()):
            tens_name = decode_str(op_data.Outputs(idx))
            output_tens = self.get_tensor_by_name(tens_name)
            if output_tens is None:
                raise AssertionError(f"Operator {op_index}: output tensor '{tens_name}' not found in subgraph")
            outputs.append(output_tens)

        # The operation is named after its first output tensor, when it has one
        name = outputs[0].name if outputs else "unknown_op_name"
        inputs = align_tensor_indices_to_nng(op_type, indices, inputs)
        op = Operation(op_type, name)
        op.op_index = op_index
        # NOTE: op.inputs and the local ``inputs`` are the same list object;
        # the in-place updates of ``inputs`` below rely on that.
        op.inputs = inputs
        op.outputs = outputs

        for out in op.outputs:
            out.ops = [op]

        # TODO Transpose_conv and conv3d
        if op.type.is_depthwise_conv2d_op() or op.type.is_conv2d_op() or op.type == Op.FullyConnected:
            if inputs[1].values is not None:
                # Constant weights: replace with a clone whose axes are permuted
                if op.type == Op.FullyConnected:
                    inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0), False)
                elif op.type.is_conv2d_op():
                    inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0), False)
                elif op.type.is_depthwise_conv2d_op():
                    inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 0, 3), False)
            if op.type.needs_bias() and len(inputs) <= op_type.info.indices.biases[0]:
                # No Bias tensor
                inputs.append(None)
            if inputs[-1] and inputs[-1].values is not None:
                # Since bias tensor is used for both bias and scale,
                # a clone with a unique equivalence_id is needed
                inputs[-1] = clone_and_reshape_tensor(inputs[-1], (0,), True)

        if attr_serializer is not None:
            op.attrs = attr_serializer.deserialize(op_data)

            if "padding" in op.attrs:
                padding = op.attrs["padding"]  # [top, bottom, left, right]
                op.attrs["explicit_padding"] = (
                    padding[0],
                    padding[2],
                    padding[1],
                    padding[3],
                )  # [top, left, bottom, right]
            if "stride" in op.attrs:
                stride = op.attrs["stride"]
                if len(stride) == 2:
                    op.attrs["strides"] = (1, stride[0], stride[1], 1)
                else:
                    # TODO CONV3D more to be done....
                    print("Unsupported kernel dimensions: ", len(stride))
                    raise AssertionError(f"Unsupported kernel dimensions: {len(stride)}")
            if "dilation" in op.attrs:
                dilation = op.attrs["dilation"]
                if len(dilation) == 2:
                    op.attrs["dilation"] = (1, dilation[0], dilation[1], 1)
                elif len(dilation) == 3:
                    # TODO CONV3D more to be done....
                    op.attrs["dilation"] = (dilation[0], dilation[1], dilation[2], 1)
            if "kernel" in op.attrs:
                kernel = op.attrs["kernel"]
                if len(kernel) == 2:
                    op.attrs["ksize"] = (1, kernel[0], kernel[1], 1)
                else:
                    # TODO CONV3D more to be done....
                    print("Unsupported kernel dimensions: ", len(kernel))
                    raise AssertionError(f"Unsupported kernel dimensions: {len(kernel)}")

        if quant_serializer is not None:
            quant_info = quant_serializer.deserialize(op_data)

            # TODO tensor zero points currently set here
            # zero points part of Rescale operation, handled in tosa_graph_optimizer
            if "input_zp" in quant_info:
                self.set_tensor_zp(op.ifm, quant_info["input_zp"])
            if "weight_zp" in quant_info:
                self.set_tensor_zp(op.weights, quant_info["weight_zp"])
            if "output_zp" in quant_info:
                self.set_tensor_zp(op.ofm, quant_info["output_zp"])
            if "a_zp" in quant_info:
                self.set_tensor_zp(op.ifm, quant_info["a_zp"])
            if "b_zp" in quant_info:
                self.set_tensor_zp(op.ifm2, quant_info["b_zp"])

    def parse_tensor(self, tens_data):
        """Create a Tensor from the flatbuffer tensor ``tens_data``.

        The tensor gets fresh quantization parameters with a scale of 1.0 and,
        for the integer types handled below, the min/max of that type. If the
        flatbuffer tensor carries data, it is reinterpreted as the matching
        numpy dtype, truncated to the number of elements given by the shape
        and stored reshaped in ``values``; otherwise ``values`` is None.

        Raises:
            AssertionError: if the tensor carries data but its type has no
                entry in ``datatype_map_numpy``.
        """
        name = decode_str(tens_data.Name())
        np_shape = tens_data.ShapeAsNumpy()
        # Anything other than an ndarray is treated as "no shape"
        shape = list(np_shape) if isinstance(np_shape, np.ndarray) else []
        tens_dtype = tens_data.Type()
        dtype = datatype_map[tens_dtype]

        tens = Tensor(shape, dtype, name)

        # Initialize quantization parameters
        tens.quantization = QuantizationParameters()

        tens.quantization.scale_f32 = 1.0
        if dtype == DataType.uint8:
            tens.quantization.quant_min = 0
            tens.quantization.quant_max = (1 << dtype.bits) - 1
        elif dtype in (DataType.int8, DataType.int16, DataType.int32, DataType.int48):
            tens.quantization.quant_min = -(1 << (dtype.bits - 1))
            tens.quantization.quant_max = (1 << (dtype.bits - 1)) - 1

        tens.values = None

        data_length = tens_data.DataLength()
        if data_length != 0:
            data_as_numpy = tens_data.DataAsNumpy()
            if tens_dtype in datatype_map_numpy:
                np_dtype = datatype_map_numpy[tens_dtype]

                # TOSA pads the tensor data
                shape_elements = shape_num_elements(shape)
                values = np.array(data_as_numpy.view(np_dtype))
                values = values[0:shape_elements]
                tens.values = values.reshape(shape)
            else:
                # int48 is only expected as an accumulated data/output format, int4 not supported
                print(f"Error: unsupported/unexpected Tensor type {dtype}, with data")
                raise AssertionError(f"Error: unsupported/unexpected Tensor type {dtype}, with data")

        return tens

    def set_tensor_zp(self, tens, zp):
        """Set the zero point of ``tens`` to ``zp``.

        A tensor without a zero point gets ``zp``; setting the same value
        again is a no-op.

        Raises:
            AssertionError: if the tensor already has a different zero point.
        """
        if tens.quantization.zero_point is None:
            tens.quantization.zero_point = zp
        elif tens.quantization.zero_point != zp:
            print("Error: Setting tensor zp not possible, tensor already has different zero point")
            raise AssertionError("Error: Setting tensor zp not possible, tensor already has different zero point")
233
234
class TosaGraph:
    """Reads a .tosa flatbuffer file and builds the corresponding nn_graph Graph.

    After construction ``subgraphs`` holds one TosaSubgraph per block in the
    file and ``nng`` holds a Graph with one Subgraph per block.
    """

    def __init__(self, filename, batch_size, feed_dict, output_node_names, initialisation_nodes):
        """Load and parse ``filename``.

        ``batch_size`` defaults to 1 when None. ``feed_dict`` and
        ``output_node_names`` are accepted but not used by this reader.

        On a struct.error, TypeError or RuntimeError during parsing, an error
        naming the failing parsing step is printed and the process exits with
        status 1.
        """

        self.op_times = {}
        if batch_size is None:
            batch_size = 1
        self.batch_size = batch_size
        self.name = os.path.splitext(os.path.basename(filename))[0]
        self.initialisation_nodes = initialisation_nodes

        with open(filename, "rb") as f:
            buf = bytearray(f.read())

        try:
            # parsing_step tracks progress so the error message below can
            # report where parsing failed
            parsing_step = "parsing root"
            tosa_graph = TG.GetRootAsTosaGraph(buf, 0)

            parsing_step = "parsing version"
            self.check_version(tosa_graph)

            parsing_step = "parsing blocks length"
            self.subgraphs = []
            for b_idx in range(tosa_graph.BlocksLength()):
                parsing_step = f"parsing block {b_idx}"
                self.subgraphs.append(TosaSubgraph(self, tosa_graph.Blocks(b_idx)))

            self.nng = Graph(self.name, self.batch_size)
            for tosa_sg in self.subgraphs:
                sg = Subgraph(tosa_sg.name)
                sg.original_inputs = tosa_sg.inputs  # Preserve the original input order
                sg.output_tensors = tosa_sg.outputs
                self.nng.subgraphs.append(sg)

        except (struct.error, TypeError, RuntimeError) as e:
            print(f'Error: Invalid .tosa file. Got "{e}" while {parsing_step}.')
            sys.exit(1)

    def check_version(self, tosa_graph):
        """Verify that ``tosa_graph`` declares TOSA version 0.22.0.

        Raises:
            AssertionError: if the file declares any other version. Raised
                explicitly (not via ``assert``) so the check is not stripped
                when Python runs with ``-O``.
        """
        version = tosa_graph.Version()
        version_str = f"{version._major()}.{version._minor()}.{version._patch()}"
        if version_str != "0.22.0":
            print(f"Unsupported TOSA version: {version_str}")
            raise AssertionError(f"Unsupported TOSA version: {version_str}")
278
279
def read_tosa(filename, batch_size, feed_dict, output_node_names, initialisation_nodes):
    """Read the .tosa file ``filename`` and return the resulting graph.

    All arguments are forwarded unchanged to TosaGraph; the graph it builds is
    refreshed with ``refresh_after_modification()`` before being returned.
    """
    graph = TosaGraph(filename, batch_size, feed_dict, output_node_names, initialisation_nodes).nng
    graph.refresh_after_modification()
    return graph
284 return nng