# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Functions used to read from a TOSA format file.
18import os.path
19import struct
20import sys
21
22import numpy as np
23
24from .nn_graph import Graph
25from .nn_graph import Subgraph
26from .operation import Op
27from .operation import Operation
Patrik Gustavsson5e26eda2021-06-30 09:07:16 +020028from .reader_util import align_tensor_indices_to_nng
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020029from .reader_util import clone_and_reshape_tensor
30from .reader_util import decode_str
31from .reader_util import fixup_tensors
32from .tensor import QuantizationParameters
33from .tensor import Tensor
34from .tflite_mapping import DataType
35from .tosa.TosaGraph import TosaGraph as TG
36from .tosa_mapping import datatype_map
37from .tosa_mapping import tosa_operator_map
38from .tosa_mapping import unsupported_tosa_operators
39
40
class TosaSubgraph:
    """One TOSA basic block parsed into nng tensors and operations.

    Tensors are parsed first, then operators (which reference tensors by
    name), and finally the block's input/output tensor lists are resolved
    with duplicates removed.
    """

    def __init__(self, file_path, graph, block):
        self.graph = graph
        self.name = decode_str(block.Name())

        # Tensors must be parsed before operators, since operators look up
        # their inputs/outputs by tensor name via get_tensor_by_name().
        self.tensors = []
        for idx in range(block.TensorsLength()):
            self.tensors.append(self.parse_tensor(block.Tensors(idx), file_path))

        for idx in range(block.OperatorsLength()):
            self.parse_operator(idx, block.Operators(idx))

        # Get the subgraph inputs and outputs
        self.inputs = self.get_sg_inputs_remove_duplicates(block)
        self.outputs = self.get_sg_outputs_remove_duplicates(block)
        fixup_tensors(self.inputs, self.tensors)

    def get_sg_inputs_remove_duplicates(self, block):
        """Return the block's input tensors, dropping duplicate references."""
        inputs = []
        for idx in range(block.InputsLength()):
            tens_data = block.Inputs(idx)
            self.add_not_duplicate(tens_data, inputs, "input")
        return inputs

    def get_sg_outputs_remove_duplicates(self, block):
        """Return the block's output tensors, dropping duplicate references."""
        outputs = []
        for idx in range(block.OutputsLength()):
            tens_data = block.Outputs(idx)
            self.add_not_duplicate(tens_data, outputs, "output")
        return outputs

    def add_not_duplicate(self, tens_data, tensors, warning_str):
        """Append the tensor named by tens_data to tensors unless already present.

        warning_str is only used for the diagnostic message ("input"/"output").
        """
        name = decode_str(tens_data)
        tensor = self.get_tensor_by_name(name)
        if tensor not in tensors:
            tensors.append(tensor)
        else:
            print(f"Warning: Subgraph {warning_str} tensor ({tensor}) already seen. Removing the duplicate.")

    def get_tensor_by_name(self, name):
        """Return the parsed tensor with the given name, or None if not found."""
        for tens in self.tensors:
            if tens.name == name:
                return tens
        return None

    def parse_operator(self, op_index, op_data):
        """Create an Operation from serialized operator data and wire it to its tensors."""
        op_code = op_data.Op()
        if op_code in unsupported_tosa_operators:
            print("Unsupported Operator", op_code)
            assert False

        op_type, attr_serializer, quant_serializer, indices = tosa_operator_map[op_code]
        inputs = []
        outputs = []
        for idx in range(op_data.InputsLength()):
            input_tens = self.get_tensor_by_name(decode_str(op_data.Inputs(idx)))
            inputs.append(input_tens)
            assert input_tens is not None

        for idx in range(op_data.OutputsLength()):
            output_tens = self.get_tensor_by_name(decode_str(op_data.Outputs(idx)))
            outputs.append(output_tens)
            assert output_tens is not None

        # Name the operation after its first output tensor when there is one.
        name = "unknown_op_name"
        if len(outputs):
            name = outputs[0].name
        inputs = align_tensor_indices_to_nng(op_type, indices, inputs)
        op = Operation(op_type, name)
        op.op_index = op_index
        op.inputs = inputs
        op.outputs = outputs

        for out in op.outputs:
            out.ops = [op]

        # Constant weight tensors are transposed from the TOSA layout to the
        # layout expected internally; the bias tensor is cloned so it gets a
        # unique equivalence_id (it serves as both bias and scale).
        # TODO Transpose_conv and conv3d
        if op.type.is_depthwise_conv2d_op() or op.type.is_conv2d_op() or op.type == Op.FullyConnected:
            if inputs[1].values is not None:
                if op.type == Op.FullyConnected:
                    inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0), False)
                elif op.type.is_conv2d_op():
                    inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0), False)
                elif op.type.is_depthwise_conv2d_op():
                    inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 0, 3), False)
            if op.type.needs_bias() and len(inputs) <= op_type.info.indices.biases[0]:
                # No Bias tensor
                inputs.append(None)
            if inputs[-1] and inputs[-1].values is not None:
                # Since bias tensor is used for both bias and scale,
                # a clone with a unique equivalence_id is needed
                inputs[-1] = clone_and_reshape_tensor(inputs[-1], (0,), True)

        if attr_serializer is not None:
            op.attrs = attr_serializer.deserialize(op_data)

            # Normalize 2D dilation/kernel attributes to the 4D NHWC form
            # used internally.
            if "dilation" in op.attrs:
                dilation = op.attrs["dilation"]
                if len(dilation) == 2:
                    op.attrs["dilation"] = (1, dilation[0], dilation[1], 1)
                elif len(dilation) == 3:
                    # TODO CONV3D more to be done....
                    op.attrs["dilation"] = (dilation[0], dilation[1], dilation[2], 1)
            if "kernel" in op.attrs:
                kernel = op.attrs["kernel"]
                if len(kernel) == 2:
                    op.attrs["ksize"] = (1, kernel[0], kernel[1], 1)
                else:
                    # TODO CONV3D more to be done....
                    print("Unsupported kernel dimensions: ", len(kernel))
                    assert False

        if quant_serializer is not None:
            quant_info = quant_serializer.deserialize(op_data)

            # TODO tensor zero points currently set here
            # zero points part of Rescale operation, handled in tosa_graph_optimizer
            if "input_zp" in quant_info:
                self.set_tensor_zp(op.ifm, quant_info["input_zp"])
            if "weight_zp" in quant_info:
                self.set_tensor_zp(op.weights, quant_info["weight_zp"])
            # Bug fix: the key check was misspelled "ouput_zp", so the OFM
            # zero point was never set (and the misspelled key would have
            # raised KeyError on the lookup below).
            if "output_zp" in quant_info:
                self.set_tensor_zp(op.ofm, quant_info["output_zp"])
            if "a_zp" in quant_info:
                self.set_tensor_zp(op.ifm, quant_info["a_zp"])
            if "b_zp" in quant_info:
                self.set_tensor_zp(op.ifm2, quant_info["b_zp"])

    def parse_tensor(self, tens_data, file_path):
        """Build a Tensor from serialized tensor data.

        Constant values are loaded from the referenced .npy file (resolved
        relative to file_path), when one is present.
        """
        name = decode_str(tens_data.Name())
        np_shape = tens_data.ShapeAsNumpy()
        shape = list(np_shape) if type(np_shape) is np.ndarray else []
        tens_dtype = tens_data.Type()
        dtype = datatype_map[tens_dtype]

        tens = Tensor(shape, dtype, name)

        # Initialize quantization parameters
        tens.quantization = QuantizationParameters()

        tens.quantization.scale_f32 = 1.0
        if dtype == DataType.uint8:
            tens.quantization.quant_min = 0
            tens.quantization.quant_max = (1 << dtype.bits) - 1
        elif dtype in (DataType.int8, DataType.int16, DataType.int32, DataType.int64):
            tens.quantization.quant_min = -(1 << (dtype.bits - 1))
            tens.quantization.quant_max = (1 << (dtype.bits - 1)) - 1

        tens.values = None
        if tens_data.NpyFilename() is not None:
            try:
                fname = decode_str(tens_data.NpyFilename())
                tens.values = np.load(os.path.join(file_path, fname))
                assert list(tens.values.shape) == tens.shape
            except (struct.error, TypeError, RuntimeError) as e:
                print(f'Error: Invalid npy file. Got "{e}" ')
                sys.exit(1)

        return tens

    def set_tensor_zp(self, tens, zp):
        """Set the tensor's zero point; abort on a conflicting value."""
        if tens.quantization.zero_point is None:
            tens.quantization.zero_point = zp
        elif tens.quantization.zero_point != zp:
            print("Error: Setting tensor zp not possible, tensor already has different zero point")
            assert False
207
208
class TosaGraph:
    """Top-level reader for a serialized .tosa flatbuffer file.

    Parses every block into a TosaSubgraph and mirrors the result into an
    nng Graph (self.nng). Parsing errors terminate the process with a
    message naming the step that failed.
    """

    def __init__(self, filename, batch_size, feed_dict, output_node_names, initialisation_nodes):

        self.op_times = {}
        # Default the batch size to 1 when none was supplied.
        self.batch_size = 1 if batch_size is None else batch_size
        self.name = os.path.splitext(os.path.basename(filename))[0]
        self.initialisation_nodes = initialisation_nodes

        with open(filename, "rb") as f:
            buf = bytearray(f.read())

        try:
            # parsing_step tracks progress so the error message below can
            # say exactly where parsing broke down.
            parsing_step = "parsing root"
            tosa_graph = TG.GetRootAsTosaGraph(buf, 0)

            parsing_step = "parsing version"
            self.check_version(tosa_graph)

            parsing_step = "parsing blocks length"
            file_path = os.path.dirname(filename)
            self.subgraphs = []
            for block_idx in range(tosa_graph.BlocksLength()):
                parsing_step = f"parsing block {block_idx}"
                self.subgraphs.append(TosaSubgraph(file_path, self, tosa_graph.Blocks(block_idx)))

            # Mirror the parsed subgraphs into the internal graph representation.
            self.nng = Graph(self.name, self.batch_size)
            for src_sg in self.subgraphs:
                dst_sg = Subgraph(src_sg.name)
                dst_sg.original_inputs = src_sg.inputs  # Preserve the original input order
                dst_sg.output_tensors = src_sg.outputs
                self.nng.subgraphs.append(dst_sg)

        except (struct.error, TypeError, RuntimeError) as e:
            print(f'Error: Invalid .tosa file. Got "{e}" while {parsing_step}.')
            sys.exit(1)

    def check_version(self, tosa_graph):
        """Abort unless the file's serializer version is exactly 0.22.0."""
        version = tosa_graph.Version()
        found = ".".join(str(part()) for part in (version._major, version._minor, version._patch))
        if found != "0.22.0":
            print(f"Unsupported TOSA version: {found}")
            assert False
253
254
def read_tosa(filename, batch_size, feed_dict, output_node_names, initialisation_nodes):
    """Parse a .tosa file and return the refreshed internal nng Graph."""
    nng = TosaGraph(filename, batch_size, feed_dict, output_node_names, initialisation_nodes).nng
    nng.refresh_after_modification()
    return nng