blob: 2d03d49635166aa63b61a45ce80eb2591937f3e4 [file] [log] [blame]
Jerry Ge1eb85042023-01-06 14:19:14 -08001# Copyright (c) 2020-2023, ARM Limited.
Kevin Chengfea5a372021-10-11 18:38:47 +00002#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
Kevin Chengfea5a372021-10-11 18:38:47 +000015import os
Kevin Chengfea5a372021-10-11 18:38:47 +000016import json
17import flatbuffers
18import numpy as np
19import struct
Jeremy Johnson9b225172021-12-14 16:34:47 +000020from enum import IntEnum, unique
Kevin Chengfea5a372021-10-11 18:38:47 +000021from tosa import (
22 TosaGraph,
Jerry Ge1eb85042023-01-06 14:19:14 -080023 TosaRegion,
Kevin Chengfea5a372021-10-11 18:38:47 +000024 TosaBasicBlock,
25 TosaTensor,
26 TosaOperator,
Kevin Chengfea5a372021-10-11 18:38:47 +000027 Version,
28)
Jeremy Johnson9b225172021-12-14 16:34:47 +000029import tosa.DType as TosaDType
30import tosa.Op as TosaOp
Kevin Chengfea5a372021-10-11 18:38:47 +000031
# Keep these version numbers in sync with the default version values declared
# in schema/tosa.fbs.
TOSA_VERSION_MAJOR = 0
TOSA_VERSION_MINOR = 51
TOSA_VERSION_PATCH = 0
TOSA_VERSION_DRAFT = True
# Aggregate form, ordered [major, minor, patch, draft].
TOSA_VERSION = [
    TOSA_VERSION_MAJOR, TOSA_VERSION_MINOR, TOSA_VERSION_PATCH, TOSA_VERSION_DRAFT
]

# Flatbuffers file identifier: the ASCII bytes "TOSA" (0x54 0x4F 0x53 0x41).
# Must be kept in sync with schema/tosa.fbs.
TOSA_GRAPH_IDENTIFIER = b"TOSA"
46
# flatc's generated Python enum classes only expose integer constants, so there
# is no programmatic way to map a DType value back to its name. This table is
# maintained by hand: DTypeNames[i] is the name of the DType whose value is i
# (dtype_str_to_val returns the index, and tensors print DTypeNames[dtype]), so
# the order must follow the generated DType enum.
DType = TosaDType.DType()
DTypeNames = [
    "UNKNOWN",  # 0
    "BOOL",  # 1
    "UINT8",  # 2
    "INT4",  # 3
    "INT8",  # 4
    "INT16",  # 5
    "INT32",  # 6
    "INT48",  # 7
    "FP32",  # 8
    "UINT16",  # 9
    "FP16",  # 10
    "BF16",  # 11
]

# Low-byte mask for unsigned NumPy integers.
ByteMask = np.uint64(0xFF)
67
68
def dtype_str_to_val(name):
    """Return the integer DType value whose name matches ``name``.

    The comparison is case-insensitive (``str.casefold``) against
    ``DTypeNames``; the returned value is the index of the matching entry.

    Raises:
        ValueError: if ``name`` matches no known DType name. ValueError is a
            subclass of Exception, so callers catching Exception still work.
    """
    wanted = name.casefold()
    for value, dtype_name in enumerate(DTypeNames):
        if dtype_name.casefold() == wanted:
            return value
    raise ValueError("Unable to parse DType name {}".format(name))
75
76
class TosaSerializerUnion:
    """This class handles encapsulating and serializing union types into flatbuffers"""

    def __init__(self):
        # (start_function, end_function) for the flatbuffers table; filled in
        # by the options constructors of subclasses.
        self.optFcns = None

        # Union type tag (a value from the generated enumeration); filled in by
        # the options constructors of subclasses.
        self.utype = None

        # Pending fields, grouped by kind. Every entry is an
        # (add_function, value) tuple queued by the options constructors.
        self.ints = []
        self.bools = []
        self.floats = []
        self.strings = []
        self.int16vecs = []
        self.intvecs = []
        self.fpvecs = []

    def serialize(self, builder):
        """Write the queued fields into ``builder``; return the value of the end function."""
        # Strings and vectors are out-of-line flatbuffers objects and must be
        # created before the table is started.
        created_strings = [
            (add, builder.CreateString(text)) for add, text in self.strings
        ]
        created_int_vecs = [
            (add, TosaSerializer.serializeInt32Vec(builder, vec))
            for add, vec in self.intvecs
        ]
        created_int_vecs += [
            (add, TosaSerializer.serializeInt16Vec(builder, vec))
            for add, vec in self.int16vecs
        ]
        created_fp_vecs = [
            (add, TosaSerializer.serializeFpVec(builder, vec))
            for add, vec in self.fpvecs
        ]

        start, end = self.optFcns

        # Build the table from the scalar fields and the offsets created above,
        # group by group in a fixed order.
        start(builder)
        for pending in (
            self.ints,
            self.bools,
            self.floats,
            created_strings,
            created_int_vecs,
            created_fp_vecs,
        ):
            for add, value in pending:
                add(builder, value)
        return end(builder)
143
144
class TosaSerializerAttribute(TosaSerializerUnion):
    """This class handles encapsulating all of the enumerated types for attributes

    Each ``*Attribute`` method sets the union tag (``utype``) and the table's
    start/end functions (``optFcns``), and queues its fields on the typed lists
    that ``TosaSerializerUnion.serialize`` writes out. Within each list, fields
    are written in the order they are queued here. The methods are meant to be
    used once per instance: a second call would overwrite ``utype``/``optFcns``
    but keep the fields queued by the first call.
    """

    def __init__(self):
        super().__init__()

    def PoolAttribute(
        self,
        kernel,
        stride,
        pad,
        input_zp,
        output_zp,
        accum_dtype,
    ):
        """Configure a PoolAttribute (pad/kernel/stride vectors, zero points, accumulator dtype)."""
        from tosa import PoolAttribute as a, Attribute

        self.utype = Attribute.Attribute().PoolAttribute

        self.optFcns = (a.Start, a.End)
        self.intvecs.append((a.AddPad, pad))
        self.intvecs.append((a.AddKernel, kernel))
        self.intvecs.append((a.AddStride, stride))
        self.ints.append((a.AddInputZp, input_zp))
        self.ints.append((a.AddOutputZp, output_zp))
        self.ints.append((a.AddAccumDtype, accum_dtype))

    def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp, accum_dtype):
        """Configure a ConvAttribute (pad/stride/dilation vectors, zero points, accumulator dtype)."""
        from tosa import ConvAttribute as a, Attribute

        self.utype = Attribute.Attribute().ConvAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddPad, pad))
        self.intvecs.append((a.AddStride, stride))
        self.intvecs.append((a.AddDilation, dilation))
        self.ints.append((a.AddInputZp, input_zp))
        self.ints.append((a.AddWeightZp, weight_zp))
        self.ints.append((a.AddAccumDtype, accum_dtype))

    def TransposeConvAttribute(
        self, outpad, stride, output_shape, input_zp, weight_zp, accum_dtype
    ):
        """Configure a TransposeConvAttribute (out_pad/stride/output_shape vectors, zero points, accumulator dtype)."""
        from tosa import TransposeConvAttribute as a, Attribute

        self.utype = Attribute.Attribute().TransposeConvAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddOutPad, outpad))
        self.intvecs.append((a.AddStride, stride))
        self.intvecs.append((a.AddOutputShape, output_shape))
        self.ints.append((a.AddInputZp, input_zp))
        self.ints.append((a.AddWeightZp, weight_zp))
        self.ints.append((a.AddAccumDtype, accum_dtype))

    def PadAttribute(self, padding, pad_const_int, pad_const_fp):
        """Configure a PadAttribute (padding vector plus integer and float pad constants)."""
        from tosa import PadAttribute as a, Attribute

        self.utype = Attribute.Attribute().PadAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddPadding, padding))
        self.ints.append((a.AddPadConstInt, pad_const_int))
        self.floats.append((a.AddPadConstFp, pad_const_fp))

    def AxisAttribute(self, axis):
        """Configure an AxisAttribute."""
        from tosa import AxisAttribute as a, Attribute

        self.utype = Attribute.Attribute().AxisAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddAxis, axis))

    def ReshapeAttribute(self, new_shape):
        """Configure a ReshapeAttribute with the target shape vector."""
        from tosa import ReshapeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ReshapeAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddNewShape, new_shape))

    def SliceAttribute(self, start, size):
        """Configure a SliceAttribute with start and size vectors."""
        from tosa import SliceAttribute as a, Attribute

        self.utype = Attribute.Attribute().SliceAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddStart, start))
        self.intvecs.append((a.AddSize, size))

    def TileAttribute(self, multiples):
        """Configure a TileAttribute with the multiples vector."""
        from tosa import TileAttribute as a, Attribute

        self.utype = Attribute.Attribute().TileAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddMultiples, multiples))

    def ResizeAttribute(self, scale, offset, border, mode):
        """Configure a ResizeAttribute (int16 scale/offset/border vectors and mode)."""
        from tosa import ResizeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ResizeAttribute
        self.optFcns = (a.Start, a.End)

        self.int16vecs.append((a.AddScale, scale))
        self.int16vecs.append((a.AddOffset, offset))
        self.int16vecs.append((a.AddBorder, border))
        self.ints.append((a.AddMode, mode))

    def ClampAttribute(self, minint, maxint, minfp, maxfp):
        """Configure a ClampAttribute with integer and floating-point bounds."""
        from tosa import ClampAttribute as a, Attribute

        self.utype = Attribute.Attribute().ClampAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddMinInt, minint))
        self.ints.append((a.AddMaxInt, maxint))

        # The floating-point bounds are queued as floats, as PadAttribute does
        # for pad_const_fp. This table has no bool fields, so they are still
        # written immediately after the integer bounds.
        self.floats.append((a.AddMinFp, minfp))
        self.floats.append((a.AddMaxFp, maxfp))

    def RescaleAttribute(
        self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel
    ):
        """Configure a RescaleAttribute (zero points, multiplier/shift vectors, flags)."""
        from tosa import RescaleAttribute as a, Attribute

        self.utype = Attribute.Attribute().RescaleAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddInputZp, input_zp))
        self.ints.append((a.AddOutputZp, output_zp))
        self.intvecs.append((a.AddMultiplier, multiplier))
        self.intvecs.append((a.AddShift, shift))
        self.bools.append((a.AddScale32, scale32))
        self.bools.append((a.AddDoubleRound, double_round))
        self.bools.append((a.AddPerChannel, per_channel))

    def MulAttribute(self, shift):
        """Configure a MulAttribute with its shift value."""
        from tosa import MulAttribute as a, Attribute

        self.utype = Attribute.Attribute().MulAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddShift, shift))

    def ArithmeticRightShiftAttribute(self, round):
        """Configure an ArithmeticRightShiftAttribute with its ``round`` flag."""
        from tosa import ArithmeticRightShiftAttribute as a, Attribute

        self.utype = Attribute.Attribute().ArithmeticRightShiftAttribute
        self.optFcns = (
            a.Start,
            a.End,
        )

        self.bools.append((a.AddRound, round))

    def CondIfAttribute(self, then_branch, else_branch):
        """Configure a CondIfAttribute with the then/else branch names."""
        from tosa import CondIfAttribute as a, Attribute

        self.utype = Attribute.Attribute().CondIfAttribute
        self.optFcns = (a.Start, a.End)

        self.strings.append((a.AddThenBranch, then_branch))
        self.strings.append((a.AddElseBranch, else_branch))

    def WhileLoopAttribute(self, cond_branch, body_branch):
        """Configure a WhileLoopAttribute with the cond/body branch names."""
        from tosa import WhileLoopAttribute as a, Attribute

        self.utype = Attribute.Attribute().WhileLoopAttribute
        self.optFcns = (a.Start, a.End)

        self.strings.append((a.AddCondBranch, cond_branch))
        self.strings.append((a.AddBodyBranch, body_branch))

    def TransposeAttribute(self, perms):
        """Configure a TransposeAttribute with the permutation vector."""
        from tosa import TransposeAttribute as a, Attribute

        self.utype = Attribute.Attribute().TransposeAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddPerms, perms))

    def TableAttribute(self, table):
        """Configure a TableAttribute with the lookup-table vector."""
        from tosa import TableAttribute as a, Attribute

        self.utype = Attribute.Attribute().TableAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddTable, table))

    def MatMulAttribute(self, A_zp, B_zp, accum_dtype):
        """Configure a MatMulAttribute (A/B zero points, accumulator dtype)."""
        from tosa import MatMulAttribute as a, Attribute

        self.utype = Attribute.Attribute().MatMulAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddAZp, A_zp))
        self.ints.append((a.AddBZp, B_zp))
        self.ints.append((a.AddAccumDtype, accum_dtype))

    def FullyConnectedAttribute(self, input_zp, weight_zp, accum_dtype):
        """Configure a FullyConnectedAttribute (zero points, accumulator dtype)."""
        from tosa import FullyConnectedAttribute as a, Attribute

        self.utype = Attribute.Attribute().FullyConnectedAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddInputZp, input_zp))
        self.ints.append((a.AddWeightZp, weight_zp))
        self.ints.append((a.AddAccumDtype, accum_dtype))

    def NegateAttribute(self, input1_zp, output_zp):
        """Configure a NegateAttribute (input and output zero points)."""
        from tosa import NegateAttribute as a, Attribute

        self.utype = Attribute.Attribute().NegateAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddInput1Zp, input1_zp))
        self.ints.append((a.AddOutputZp, output_zp))
362 self.ints.append((a.AddOutputZp, output_zp))
Kevin Chengfea5a372021-10-11 18:38:47 +0000363
364
class TosaSerializerTensor:
    """A named tensor (shape, dtype and optional constant data) to serialize."""

    def __init__(
        self,
        name,
        shape,
        dtype,
        data=None,
        placeholderFilename=None,
    ):
        self.name = name

        # Normalize the shape to a plain list of Python ints.
        if isinstance(shape, np.ndarray):
            shape = shape.astype(int).tolist()
        self.shape = list(map(int, shape))
        self.dtype = dtype

        # Element converter for the stored data. BF16 values are held as
        # float32 and serialized exactly like FP32 (see _dataToBytes).
        if dtype == DType.FP32 or dtype == DType.BF16:
            fntype = np.float32
        elif dtype == DType.FP16:
            fntype = np.float16
        else:
            fntype = int

        # Data is stored as a flat list. NumPy arrays, lists and tuples are
        # accepted; any other value (including None) means "no data".
        if isinstance(data, np.ndarray):
            data = data.flatten().astype(fntype).tolist()
            self.data = list(map(fntype, data))
        elif isinstance(data, (list, tuple)):
            self.data = list(map(fntype, data))
        else:
            self.data = None

        # Filename for placeholder tensors. These get generated by the test
        # generation process and are written to disk, but are considered input
        # tensors by the network so they do not appear in the TOSA
        # serialization. However, if we want to form a unit test around these
        # input tensors, we can get the filename from here.
        self.placeholderFilename = placeholderFilename

    def __str__(self):
        concatString = "TosaSerializerTensor name: {} shape: {} dtype: {}".format(
            self.name,
            self.shape,
            DTypeNames[self.dtype],
        )
        return concatString

    def setDtype(self, dtype):
        self.dtype = dtype

    @staticmethod
    def _packLittleEndianInts(values, nbytes):
        """Return ``values`` as a flat list of little-endian bytes, ``nbytes`` per value.

        Each value is reduced modulo 2**(8*nbytes) first, so negative numbers
        are encoded in two's complement. Plain Python ints are used because
        NumPy 2 raises OverflowError for conversions such as ``np.uint8(-1)``.
        """
        mask = (1 << (8 * nbytes)) - 1
        packed = bytearray()
        for value in values:
            packed += (int(value) & mask).to_bytes(nbytes, "little")
        return list(packed)

    @staticmethod
    def _packInt4(values):
        """Pack 4-bit values two per byte, low nibble first.

        An odd trailing value is paired with a zero high nibble.
        """
        count = len(values)
        packed = []
        for i in range(0, count, 2):
            low = int(values[i]) & 0xF
            high = int(values[i + 1]) & 0xF if i + 1 < count else 0
            packed.append(low | (high << 4))
        return packed

    def _dataToBytes(self):
        """Return ``self.data`` encoded as a list of little-endian bytes.

        Raises:
            Exception: if ``self.dtype`` has no byte encoding here.
        """
        dtype = self.dtype
        data = self.data
        if dtype == DType.BOOL or dtype == DType.INT8:
            return self._packLittleEndianInts(data, 1)
        if dtype == DType.INT4:
            return self._packInt4(data)
        if dtype == DType.INT16:
            return self._packLittleEndianInts(data, 2)
        if dtype == DType.INT32:
            return self._packLittleEndianInts(data, 4)
        if dtype == DType.INT48:
            return self._packLittleEndianInts(data, 6)
        if dtype == DType.FP16:
            return list(np.array(data, dtype="<f2").view(np.uint8))
        if dtype == DType.FP32 or dtype == DType.BF16:
            # BF16 data is emitted as full little-endian float32 values.
            return list(struct.pack("<{}f".format(len(data)), *data))
        if dtype == TosaDType.DType:
            # NOTE(review): this compares dtype with the generated DType
            # *class*, so integer dtype values never match it. Serialize DType
            # enum data as uint32 values (once, not once per element).
            return list(np.array(data, dtype="<u4").view(np.uint8))
        raise Exception("unsupported data type {}".format(DTypeNames[self.dtype]))

    def serialize(self, builder):
        """Serialize this tensor into ``builder`` and return the table offset."""
        fb_name = builder.CreateString(self.name)
        fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape)
        if self.data:
            # The data vector must be created before the table is started.
            fb_data = TosaSerializer.serializeUint8Vec(builder, self._dataToBytes())

        TosaTensor.Start(builder)
        TosaTensor.AddName(builder, fb_name)
        TosaTensor.AddShape(builder, fb_shapes)
        TosaTensor.AddType(builder, self.dtype)
        if self.data:
            TosaTensor.AddData(builder, fb_data)

        return TosaTensor.End(builder)
Kevin Chengfea5a372021-10-11 18:38:47 +0000494
495
class TosaSerializerOperator:
    """One operator: opcode, input/output tensor names and optional attribute."""

    def __init__(self, op, inputs, outputs, attributes=None):
        self.op = op
        self.attributes = attributes
        # TosaSerializer.toList (defined elsewhere in this file) presumably
        # wraps a single name into a list; addConst passes a bare name.
        self.inputs = TosaSerializer.toList(inputs)
        self.outputs = TosaSerializer.toList(outputs)

    def __str__(self):
        lines = ["Op {}\n----\n".format(self.op)]
        lines.extend(" Input: {}\n".format(name) for name in self.inputs)
        lines.extend(" Output: {}\n".format(name) for name in self.outputs)
        return "".join(lines)

    def serialize(self, builder):
        """Serialize this operator into ``builder`` and return the table offset."""
        fb_inputs = TosaSerializer.serializeStrVec(
            builder, self.inputs, TosaOperator.StartInputsVector
        )
        fb_outputs = TosaSerializer.serializeStrVec(
            builder, self.outputs, TosaOperator.StartOutputsVector
        )
        # The attribute table is an out-of-line object, so it must exist
        # before the operator table is started.
        has_attributes = self.attributes is not None
        fb_attributes = self.attributes.serialize(builder) if has_attributes else None

        TosaOperator.Start(builder)
        TosaOperator.AddOp(builder, self.op)
        TosaOperator.AddInputs(builder, fb_inputs)
        TosaOperator.AddOutputs(builder, fb_outputs)
        if has_attributes:
            TosaOperator.AddAttributeType(builder, self.attributes.utype)
            TosaOperator.AddAttribute(builder, fb_attributes)

        return TosaOperator.End(builder)
Kevin Chengfea5a372021-10-11 18:38:47 +0000533
534
class TosaSerializerBasicBlock:
    """A basic block: named tensors, operators and block-level inputs/outputs."""

    def __init__(self, name, pathPrefix, saveConstsToFile=False):
        self.name = name
        self.pathPrefix = pathPrefix
        self.operators = []
        self.saveConstsToFile = saveConstsToFile

        # Keyed by tensor name: keeps names unique and allows lookup by name.
        self.tensors = {}

        # Names of the block's input and output tensors, in insertion order.
        self.inputs = []
        self.outputs = []

    def addTensor(
        self,
        name,
        shape,
        dtype,
        data=None,
        placeholderFilename=None,
    ):
        """Return the tensor called ``name``, creating it first if it is new.

        If a tensor with this name already exists, it is returned unchanged
        and the other arguments are ignored.
        """
        if name in self.tensors:
            return self.tensors[name]

        tensor = TosaSerializerTensor(name, shape, dtype, data, placeholderFilename)
        self.tensors[name] = tensor
        return tensor

    def addInput(self, name):
        """Record ``name`` as an input of this block."""
        self.inputs.append(name)

    def addOutput(self, name):
        """Record ``name`` as an output of this block."""
        self.outputs.append(name)

    def addOperator(self, op, inputs, outputs, attributes=None):
        """Append a new operator to this block."""
        operator = TosaSerializerOperator(op, inputs, outputs, attributes)
        self.operators.append(operator)

    def serialize(self, builder):
        """Serialize this block into ``builder`` and return the table offset."""
        # All child objects are created before the block table is started.
        fb_name = builder.CreateString(self.name)
        fbv_inputs = TosaSerializer.serializeStrVec(
            builder, list(self.inputs), TosaBasicBlock.StartInputsVector
        )
        fbv_outputs = TosaSerializer.serializeStrVec(
            builder, list(self.outputs), TosaBasicBlock.StartOutputsVector
        )
        fbv_tensors = TosaSerializer.serializeObjVec(
            builder, list(self.tensors.values()), TosaBasicBlock.StartTensorsVector
        )
        fbv_operators = TosaSerializer.serializeObjVec(
            builder, self.operators, TosaBasicBlock.StartOperatorsVector
        )

        TosaBasicBlock.Start(builder)
        TosaBasicBlock.AddName(builder, fb_name)
        TosaBasicBlock.AddInputs(builder, fbv_inputs)
        TosaBasicBlock.AddOutputs(builder, fbv_outputs)
        TosaBasicBlock.AddTensors(builder, fbv_tensors)
        TosaBasicBlock.AddOperators(builder, fbv_operators)
        return TosaBasicBlock.End(builder)
Kevin Chengfea5a372021-10-11 18:38:47 +0000596
597
class TosaSerializerRegion:
    """A named region: an ordered list of basic blocks plus tensor-name counters.

    Tensors and operators are added to the most recently added basic block
    (``currBasicBlock``).
    """

    def __init__(self, name, pathPrefix, saveConstsToFile=False):
        self.name = name
        self.basicBlocks = []
        # Target of the add* methods; set by addBasicBlock(). Initialized here
        # so that using the region before a block exists raises the explicit
        # error from _requireBasicBlock() instead of an AttributeError.
        self.currBasicBlock = None
        # Counters for generated tensor names. Placeholders and constants share
        # currInputIdx; currConstIdx is kept for compatibility but is not read
        # by this class.
        self.currInputIdx = 0
        self.currConstIdx = 0
        self.currLayerIdx = 1
        self.currResultIdx = 0
        # Directory that placeholder (and, optionally, constant) .npy files are
        # saved into.
        self.pathPrefix = pathPrefix
        self.saveConstsToFile = saveConstsToFile

    def _requireBasicBlock(self):
        """Return the current basic block, raising if none has been added yet."""
        if not self.currBasicBlock:
            raise Exception("addTensor called without valid basic block")
        return self.currBasicBlock

    def addBasicBlock(self, name, pathPrefix, saveConstsToFile):
        """Append a new basic block and make it the current one."""
        self.currBasicBlock = TosaSerializerBasicBlock(
            name, pathPrefix, saveConstsToFile
        )
        self.basicBlocks.append(self.currBasicBlock)

    def serialize(self, builder):
        """Serialize the region and its basic blocks; return the table offset."""
        fb_name = builder.CreateString(self.name)
        fbv_basicBlocks = TosaSerializer.serializeObjVec(
            builder, self.basicBlocks, TosaRegion.StartBlocksVector
        )

        TosaRegion.Start(builder)
        TosaRegion.AddName(builder, fb_name)
        TosaRegion.AddBlocks(builder, fbv_basicBlocks)
        return TosaRegion.End(builder)

    def addPlaceholder(self, shape, dtype, vals):
        """Add an ``input-N`` tensor as an input of the current block.

        If ``vals`` is not None, it is saved to ``<pathPrefix>/input-N.npy``
        (with pickling disabled); the tensor records that filename.
        Returns the tensor.
        """
        block = self._requireBasicBlock()

        name = "input-{}".format(self.currInputIdx)
        filename = "{}.npy".format(name)
        self.currInputIdx = self.currInputIdx + 1

        tens = block.addTensor(name, shape, dtype, None, filename)
        # This is always an input to the block
        block.addInput(name)

        if vals is not None:
            np.save(os.path.join(self.pathPrefix, filename), vals, False)

        return tens

    def addConst(self, shape, dtype, vals):
        """Add a ``const-N`` tensor holding ``vals`` plus the CONST operator producing it.

        When ``saveConstsToFile`` is set, ``vals`` is also saved to
        ``<pathPrefix>/const-N.npy``. Returns the tensor.
        """
        block = self._requireBasicBlock()

        # Constants share the input counter with placeholders.
        name = "const-{}".format(self.currInputIdx)
        self.currInputIdx = self.currInputIdx + 1

        tens = block.addTensor(name, shape, dtype, vals)
        # Add the operator now
        block.addOperator(TosaOp.Op().CONST, [], name)

        if self.saveConstsToFile:
            filename = "{}.npy".format(name)
            np.save(os.path.join(self.pathPrefix, filename), vals, False)

        return tens

    def addIntermediate(self, shape, dtype):
        """Add a ``layer-N`` tensor (no data) to the current block and return it."""
        block = self._requireBasicBlock()

        name = "layer-{}".format(self.currLayerIdx)
        self.currLayerIdx = self.currLayerIdx + 1

        return block.addTensor(name, shape, dtype, None)

    def addInputTensor(self, tensor):
        """Add an existing tensor's name/shape/dtype to the current block as an input."""
        block = self._requireBasicBlock()
        block.addTensor(tensor.name, tensor.shape, tensor.dtype)
        block.addInput(tensor.name)

    def addOutputTensor(self, tensor):
        """Record an existing tensor's name as an output of the current block."""
        self._requireBasicBlock().addOutput(tensor.name)

    def addOutput(self, shape, dtype):
        """Add a ``result-N`` tensor as an output of the current block and return it."""
        block = self._requireBasicBlock()

        name = "result-{}".format(self.currResultIdx)
        self.currResultIdx = self.currResultIdx + 1

        tens = block.addTensor(name, shape, dtype, None)
        block.addOutput(name)
        return tens

    def addOperator(self, op, inputs, outputs, attributes=None):
        """Add a non-CONST operator to the current block.

        Raises:
            Exception: for CONST ops (use addConst instead) or when no basic
                block has been added.
        """
        if op == TosaOp.Op().CONST:
            raise Exception("Use addConst() to add CONST ops")

        return self._requireBasicBlock().addOperator(
            op,
            inputs,
            outputs,
            attributes,
        )
699
Jerry Ge1eb85042023-01-06 14:19:14 -0800700
@unique
class TensorDir(IntEnum):
    """Tensor categories.

    The member names mirror the addPlaceholder / addConst / addIntermediate /
    addOutput helpers. Values are fixed integers, and ``@unique`` rejects
    accidental aliases.
    """

    PLACEHOLDER = 0
    CONST = 1
    INTERMEDIATE = 2
    RESULT = 3
707
708
709class TosaSerializer:
710 def __init__(self, pathPrefix, saveConstsToFile=False):
711 self.add_compat_methods()
712 # Get the global TOSA version if not already defined
713
714 self.builder = flatbuffers.Builder(0)
715
716 self.regions = []
717 self.startRegion("main", pathPrefix, saveConstsToFile)
718
719 # Enables inspection of constant data outside of graph
720 self.saveConstsToFile = saveConstsToFile
721
722 self.currRegion.addBasicBlock("main", pathPrefix, self.saveConstsToFile)
723
724 # Is this an illegal test that is expected to fail?
725 self.expectedReturnCode = 0
726 self.expectedFailure = False
727 self.expectedFailureDesc = ""
728
729 def __str__(self):
730 concatString = ""
731 for region in self.regions:
732 concatString = concatString + str(region)
733 return concatString
734
735 def addPlaceholder(self, shape, dtype, vals):
736 return self.currRegion.addPlaceholder(shape, dtype, vals)
737
738 def addConst(self, shape, dtype, vals):
739 return self.currRegion.addConst(shape, dtype, vals)
740
741 def addIntermediate(self, shape, dtype):
742 return self.currRegion.addIntermediate(shape, dtype)
743
744 def addInputTensor(self, tensor):
745 self.currRegion.addInputTensor(tensor)
746
747 def addOutputTensor(self, tensor):
748 self.currRegion.addOutputTensor(tensor)
749
750 def addOutput(self, shape, dtype):
751 return self.currRegion.addOutput(shape, dtype)
752
753 def addOperator(self, op, inputs, outputs, attributes=None):
754 return self.currRegion.addOperator(op, inputs, outputs, attributes)
755
Jeremy Johnson9b225172021-12-14 16:34:47 +0000756 def setExpectedReturnCode(self, val, fail, desc=""):
Kevin Chengfea5a372021-10-11 18:38:47 +0000757
758 self.expectedReturnCode = val
759 self.expectedFailureDesc = desc
Jeremy Johnson9b225172021-12-14 16:34:47 +0000760 self.expectedFailure = fail
Kevin Chengfea5a372021-10-11 18:38:47 +0000761
762 def serialize(self):
763
764 builder = self.builder
765
Kevin Cheng49faa4e2021-11-08 16:59:18 -0800766 Version.Start(builder)
767 Version.Add_major(builder, TOSA_VERSION[0])
768 Version.Add_minor(builder, TOSA_VERSION[1])
769 Version.Add_patch(builder, TOSA_VERSION[2])
770 Version.Add_draft(builder, TOSA_VERSION[3])
771 version = Version.End(builder)
Kevin Chengfea5a372021-10-11 18:38:47 +0000772
Jerry Ge1eb85042023-01-06 14:19:14 -0800773 fbv_region = TosaSerializer.serializeObjVec(
774 builder, self.regions, TosaGraph.StartRegionsVector
Kevin Chengfea5a372021-10-11 18:38:47 +0000775 )
776
Kevin Cheng49faa4e2021-11-08 16:59:18 -0800777 TosaGraph.Start(builder)
778 TosaGraph.AddVersion(builder, version)
Jerry Ge1eb85042023-01-06 14:19:14 -0800779 TosaGraph.AddRegions(builder, fbv_region)
Kevin Cheng49faa4e2021-11-08 16:59:18 -0800780 graph = TosaGraph.End(builder)
Kevin Chengfea5a372021-10-11 18:38:47 +0000781
Eric Kunzee6596402022-06-09 21:27:36 +0000782 self.builder.Finish(graph, TOSA_GRAPH_IDENTIFIER)
Kevin Chengfea5a372021-10-11 18:38:47 +0000783 return self.builder.Output()
784
785 def writeJson(self, tosa_filename):
786 """Write a json test file so that it is fairly easy to pick up the test
787 and generate commands for third party tool"""
788 test_desc = dict()
789
790 test_desc["tosa_file"] = tosa_filename
791 ifm_name = []
792 ifm_file = []
793 ofm_name = []
794 ofm_file = []
795
Jerry Ge1eb85042023-01-06 14:19:14 -0800796 for region in self.regions:
797 for block in region.basicBlocks:
798 if block:
799 for i in block.inputs:
800 ifm_name.append(i)
801 ifm_file.append(block.tensors[i].placeholderFilename)
802 for o in block.outputs:
803 ofm_name.append(o)
804 # Make up an OFM filename here. One isn't generated until the
805 # reference tool is run, so any name is a good name
806 ofm_file.append("ref-{}.npy".format(o))
Kevin Chengfea5a372021-10-11 18:38:47 +0000807
808 test_desc["ifm_name"] = ifm_name
809 test_desc["ifm_file"] = ifm_file
810 test_desc["ofm_name"] = ofm_name
811 test_desc["ofm_file"] = ofm_file
812 test_desc["expected_return_code"] = self.expectedReturnCode
813 test_desc["expected_failure"] = self.expectedFailure
814 if self.expectedFailureDesc:
815 test_desc["expected_failure_desc"] = self.expectedFailureDesc
816
817 return json.dumps(test_desc, indent=" ")
818
Jerry Ge1eb85042023-01-06 14:19:14 -0800819 def startRegion(self, name, pathPrefix, saveConstsToFile):
820 self.currRegion = TosaSerializerRegion(name, pathPrefix, saveConstsToFile)
821 self.regions.append(self.currRegion)
Kevin Chengfea5a372021-10-11 18:38:47 +0000822
823 @staticmethod
824 def serializeStrVec(builder, vec, start_fcn):
825 fb_strs = [builder.CreateString(i) for i in vec]
826 start_fcn(builder, len(fb_strs))
827 for s in fb_strs[::-1]:
828 builder.PrependUOffsetTRelative(s)
Eric Kunzeae906de2022-05-30 22:40:47 -0700829 try:
830 return builder.EndVector()
831 except TypeError:
832 return builder.EndVector(len(vec))
Kevin Chengfea5a372021-10-11 18:38:47 +0000833
834 @staticmethod
835 def serializeUint8Vec(builder, vec):
836 builder.StartVector(1, len(vec), 8)
837 for v in vec[::-1]:
838 builder.PrependUint8(v)
839 try:
840 return builder.EndVector()
841 except TypeError:
842 return builder.EndVector(len(vec))
843
844 @staticmethod
TatWai Chong49b1ca62022-06-10 01:49:13 -0700845 def serializeInt16Vec(builder, vec):
846 builder.StartVector(2, len(vec), 4)
847 for v in vec[::-1]:
848 builder.PrependInt16(v)
849 try:
850 return builder.EndVector()
851 except TypeError:
852 return builder.EndVector(len(vec))
853
854 @staticmethod
Kevin Chengfea5a372021-10-11 18:38:47 +0000855 def serializeInt32Vec(builder, vec):
856 builder.StartVector(4, len(vec), 4)
857 for v in vec[::-1]:
858 builder.PrependInt32(v)
859 try:
860 return builder.EndVector()
861 except TypeError:
862 return builder.EndVector(len(vec))
863
864 @staticmethod
865 def serializeFpVec(builder, vec):
866 builder.StartVector(4, len(vec), 4)
867 for v in vec[::-1]:
868 builder.PrependFloat32(v)
869 try:
870 return builder.EndVector()
871 except TypeError:
872 return builder.EndVector(len(vec))
873
874 @staticmethod
875 def serializeObjVec(builder, vec, start_fcn):
876 serialized_vec = []
877 for v in vec[::-1]:
878 serialized_vec.append(v.serialize(builder))
879
880 start_fcn(builder, len(vec))
881 for v in serialized_vec:
882 builder.PrependUOffsetTRelative(v)
883 try:
884 return builder.EndVector()
885 except TypeError:
886 return builder.EndVector(len(vec))
887
888 @staticmethod
889 def toList(val):
890 if isinstance(val, list):
891 return val
892 else:
893 return [val]
Eric Kunzeae906de2022-05-30 22:40:47 -0700894
895 # Remove when switching to flatbuffers 2.0
896 # contains a mapping of the deprecated 1.12 method to the 2.0 version
897
898 def add_compat_methods(self):
899
900 from tosa import ArithmeticRightShiftAttribute
901
902 if not hasattr(ArithmeticRightShiftAttribute, "Start"):
903 ArithmeticRightShiftAttribute.Start = (
904 ArithmeticRightShiftAttribute.ArithmeticRightShiftAttributeStart
905 )
906 ArithmeticRightShiftAttribute.AddRound = (
907 ArithmeticRightShiftAttribute.ArithmeticRightShiftAttributeAddRound
908 )
909 ArithmeticRightShiftAttribute.End = (
910 ArithmeticRightShiftAttribute.ArithmeticRightShiftAttributeEnd
911 )
912 from tosa import AxisAttribute
913
914 if not hasattr(AxisAttribute, "Start"):
915 AxisAttribute.Start = AxisAttribute.AxisAttributeStart
916 AxisAttribute.AddAxis = AxisAttribute.AxisAttributeAddAxis
917 AxisAttribute.End = AxisAttribute.AxisAttributeEnd
918 from tosa import ClampAttribute
919
920 if not hasattr(ClampAttribute, "Start"):
921 ClampAttribute.Start = ClampAttribute.ClampAttributeStart
922 ClampAttribute.AddMinInt = ClampAttribute.ClampAttributeAddMinInt
923 ClampAttribute.AddMaxInt = ClampAttribute.ClampAttributeAddMaxInt
924 ClampAttribute.AddMinFp = ClampAttribute.ClampAttributeAddMinFp
925 ClampAttribute.AddMaxFp = ClampAttribute.ClampAttributeAddMaxFp
926 ClampAttribute.End = ClampAttribute.ClampAttributeEnd
927 from tosa import CondIfAttribute
928
929 if not hasattr(CondIfAttribute, "Start"):
930 CondIfAttribute.Start = CondIfAttribute.CondIfAttributeStart
931 CondIfAttribute.AddThenBranch = CondIfAttribute.CondIfAttributeAddThenBranch
932 CondIfAttribute.AddElseBranch = CondIfAttribute.CondIfAttributeAddElseBranch
933 CondIfAttribute.End = CondIfAttribute.CondIfAttributeEnd
934 from tosa import ConvAttribute
935
936 if not hasattr(ConvAttribute, "Start"):
937 ConvAttribute.Start = ConvAttribute.ConvAttributeStart
938 ConvAttribute.AddPad = ConvAttribute.ConvAttributeAddPad
939 ConvAttribute.StartPadVector = ConvAttribute.ConvAttributeStartPadVector
940 ConvAttribute.AddStride = ConvAttribute.ConvAttributeAddStride
941 ConvAttribute.StartStrideVector = (
942 ConvAttribute.ConvAttributeStartStrideVector
943 )
944 ConvAttribute.AddDilation = ConvAttribute.ConvAttributeAddDilation
945 ConvAttribute.StartDilationVector = (
946 ConvAttribute.ConvAttributeStartDilationVector
947 )
Eric Kunzebdcc3fe2022-06-07 05:17:37 +0000948 ConvAttribute.AddInputZp = ConvAttribute.ConvAttributeAddInputZp
949 ConvAttribute.AddWeightZp = ConvAttribute.ConvAttributeAddWeightZp
James Ward485a11d2022-08-05 13:48:37 +0100950 ConvAttribute.AddAccumDtype = ConvAttribute.ConvAttributeAddAccumDtype
Eric Kunzeae906de2022-05-30 22:40:47 -0700951 ConvAttribute.End = ConvAttribute.ConvAttributeEnd
Eric Kunzebdcc3fe2022-06-07 05:17:37 +0000952 from tosa import FullyConnectedAttribute
Eric Kunzeae906de2022-05-30 22:40:47 -0700953
Eric Kunzebdcc3fe2022-06-07 05:17:37 +0000954 if not hasattr(FullyConnectedAttribute, "Start"):
955 FullyConnectedAttribute.Start = (
956 FullyConnectedAttribute.FullyConnectedAttributeStart
957 )
958 FullyConnectedAttribute.AddInputZp = (
959 FullyConnectedAttribute.FullyConnectedAttributeAddInputZp
960 )
961 FullyConnectedAttribute.AddWeightZp = (
962 FullyConnectedAttribute.FullyConnectedAttributeAddWeightZp
963 )
James Ward485a11d2022-08-05 13:48:37 +0100964 FullyConnectedAttribute.AddAccumDtype = (
965 FullyConnectedAttribute.FullyConnectedAttributeAddAccumDtype
966 )
Eric Kunzebdcc3fe2022-06-07 05:17:37 +0000967 FullyConnectedAttribute.End = (
968 FullyConnectedAttribute.FullyConnectedAttributeEnd
969 )
970 from tosa import MatMulAttribute
Eric Kunzeae906de2022-05-30 22:40:47 -0700971
Eric Kunzebdcc3fe2022-06-07 05:17:37 +0000972 if not hasattr(MatMulAttribute, "Start"):
973 MatMulAttribute.Start = MatMulAttribute.MatMulAttributeStart
974 MatMulAttribute.AddAZp = MatMulAttribute.MatMulAttributeAddAZp
975 MatMulAttribute.AddBZp = MatMulAttribute.MatMulAttributeAddBZp
James Ward485a11d2022-08-05 13:48:37 +0100976 MatMulAttribute.AddAccumDtype = MatMulAttribute.MatMulAttributeAddAccumDtype
Eric Kunzebdcc3fe2022-06-07 05:17:37 +0000977 MatMulAttribute.End = MatMulAttribute.MatMulAttributeEnd
978 from tosa import PoolAttribute
979
980 if not hasattr(PoolAttribute, "Start"):
981 PoolAttribute.Start = PoolAttribute.PoolAttributeStart
982 PoolAttribute.AddPad = PoolAttribute.PoolAttributeAddPad
983 PoolAttribute.StartPadVector = PoolAttribute.PoolAttributeStartPadVector
984 PoolAttribute.AddKernel = PoolAttribute.PoolAttributeAddKernel
985 PoolAttribute.StartKernelVector = (
986 PoolAttribute.PoolAttributeStartKernelVector
987 )
988 PoolAttribute.AddStride = PoolAttribute.PoolAttributeAddStride
989 PoolAttribute.StartStrideVector = (
990 PoolAttribute.PoolAttributeStartStrideVector
991 )
James Ward485a11d2022-08-05 13:48:37 +0100992 PoolAttribute.AddAccumDtype = PoolAttribute.PoolAttributeAddAccumDtype
Eric Kunzebdcc3fe2022-06-07 05:17:37 +0000993 PoolAttribute.AddInputZp = PoolAttribute.PoolAttributeAddInputZp
994 PoolAttribute.AddOutputZp = PoolAttribute.PoolAttributeAddOutputZp
995 PoolAttribute.End = PoolAttribute.PoolAttributeEnd
Eric Kunzeae906de2022-05-30 22:40:47 -0700996 from tosa import MulAttribute
997
998 if not hasattr(MulAttribute, "Start"):
999 MulAttribute.Start = MulAttribute.MulAttributeStart
1000 MulAttribute.AddShift = MulAttribute.MulAttributeAddShift
1001 MulAttribute.End = MulAttribute.MulAttributeEnd
1002 from tosa import PadAttribute
1003
1004 if not hasattr(PadAttribute, "Start"):
1005 PadAttribute.Start = PadAttribute.PadAttributeStart
1006 PadAttribute.AddPadding = PadAttribute.PadAttributeAddPadding
1007 PadAttribute.StartPaddingVector = (
1008 PadAttribute.PadAttributeStartPaddingVector
1009 )
1010 PadAttribute.AddPadConstInt = PadAttribute.PadAttributeAddPadConstInt
1011 PadAttribute.AddPadConstFp = PadAttribute.PadAttributeAddPadConstFp
1012 PadAttribute.End = PadAttribute.PadAttributeEnd
Eric Kunzeae906de2022-05-30 22:40:47 -07001013 from tosa import PoolAttribute
1014
1015 if not hasattr(PoolAttribute, "Start"):
1016 PoolAttribute.Start = PoolAttribute.PoolAttributeStart
1017 PoolAttribute.AddPad = PoolAttribute.PoolAttributeAddPad
1018 PoolAttribute.StartPadVector = PoolAttribute.PoolAttributeStartPadVector
1019 PoolAttribute.AddKernel = PoolAttribute.PoolAttributeAddKernel
1020 PoolAttribute.StartKernelVector = (
1021 PoolAttribute.PoolAttributeStartKernelVector
1022 )
1023 PoolAttribute.AddStride = PoolAttribute.PoolAttributeAddStride
1024 PoolAttribute.StartStrideVector = (
1025 PoolAttribute.PoolAttributeStartStrideVector
1026 )
James Ward485a11d2022-08-05 13:48:37 +01001027 PoolAttribute.AddAccumDtype = PoolAttribute.PoolAttributeAddAccumDtype
Eric Kunzebdcc3fe2022-06-07 05:17:37 +00001028 PoolAttribute.AddInputZp = PoolAttribute.PoolAttributeAddInputZp
1029 PoolAttribute.AddOutputZp = PoolAttribute.PoolAttributeAddOutputZp
Eric Kunzeae906de2022-05-30 22:40:47 -07001030 PoolAttribute.End = PoolAttribute.PoolAttributeEnd
1031 from tosa import RescaleAttribute
1032
1033 if not hasattr(RescaleAttribute, "Start"):
1034 RescaleAttribute.Start = RescaleAttribute.RescaleAttributeStart
1035 RescaleAttribute.AddInputZp = RescaleAttribute.RescaleAttributeAddInputZp
1036 RescaleAttribute.AddOutputZp = RescaleAttribute.RescaleAttributeAddOutputZp
1037 RescaleAttribute.AddMultiplier = (
1038 RescaleAttribute.RescaleAttributeAddMultiplier
1039 )
1040 RescaleAttribute.StartMultiplierVector = (
1041 RescaleAttribute.RescaleAttributeStartMultiplierVector
1042 )
1043 RescaleAttribute.AddShift = RescaleAttribute.RescaleAttributeAddShift
1044 RescaleAttribute.StartShiftVector = (
1045 RescaleAttribute.RescaleAttributeStartShiftVector
1046 )
1047 RescaleAttribute.AddScale32 = RescaleAttribute.RescaleAttributeAddScale32
1048 RescaleAttribute.AddDoubleRound = (
1049 RescaleAttribute.RescaleAttributeAddDoubleRound
1050 )
1051 RescaleAttribute.AddPerChannel = (
1052 RescaleAttribute.RescaleAttributeAddPerChannel
1053 )
1054 RescaleAttribute.End = RescaleAttribute.RescaleAttributeEnd
1055 from tosa import ReshapeAttribute
1056
1057 if not hasattr(ReshapeAttribute, "Start"):
1058 ReshapeAttribute.Start = ReshapeAttribute.ReshapeAttributeStart
1059 ReshapeAttribute.AddNewShape = ReshapeAttribute.ReshapeAttributeAddNewShape
1060 ReshapeAttribute.StartNewShapeVector = (
1061 ReshapeAttribute.ReshapeAttributeStartNewShapeVector
1062 )
1063 ReshapeAttribute.End = ReshapeAttribute.ReshapeAttributeEnd
1064 from tosa import ResizeAttribute
1065
1066 if not hasattr(ResizeAttribute, "Start"):
1067 ResizeAttribute.Start = ResizeAttribute.ResizeAttributeStart
TatWai Chong49b1ca62022-06-10 01:49:13 -07001068 ResizeAttribute.AddScale = ResizeAttribute.ResizeAttributeAddScale
1069 ResizeAttribute.StartScaleVector = (
1070 ResizeAttribute.ResizeAttributeStartScaleVector
Eric Kunzeae906de2022-05-30 22:40:47 -07001071 )
1072 ResizeAttribute.AddOffset = ResizeAttribute.ResizeAttributeAddOffset
1073 ResizeAttribute.StartOffsetVector = (
1074 ResizeAttribute.ResizeAttributeStartOffsetVector
1075 )
TatWai Chong49b1ca62022-06-10 01:49:13 -07001076 ResizeAttribute.AddBorder = ResizeAttribute.ResizeAttributeAddBorder
1077 ResizeAttribute.StartBorderVector = (
1078 ResizeAttribute.ResizeAttributeStartBorderVector
Eric Kunzeae906de2022-05-30 22:40:47 -07001079 )
1080 ResizeAttribute.AddMode = ResizeAttribute.ResizeAttributeAddMode
1081 ResizeAttribute.End = ResizeAttribute.ResizeAttributeEnd
1082 from tosa import SliceAttribute
1083
1084 if not hasattr(SliceAttribute, "Start"):
1085 SliceAttribute.Start = SliceAttribute.SliceAttributeStart
1086 SliceAttribute.AddStart = SliceAttribute.SliceAttributeAddStart
1087 SliceAttribute.StartStartVector = (
1088 SliceAttribute.SliceAttributeStartStartVector
1089 )
1090 SliceAttribute.AddSize = SliceAttribute.SliceAttributeAddSize
1091 SliceAttribute.StartSizeVector = (
1092 SliceAttribute.SliceAttributeStartSizeVector
1093 )
1094 SliceAttribute.End = SliceAttribute.SliceAttributeEnd
1095 from tosa import TableAttribute
1096
1097 if not hasattr(TableAttribute, "Start"):
1098 TableAttribute.Start = TableAttribute.TableAttributeStart
1099 TableAttribute.AddTable = TableAttribute.TableAttributeAddTable
1100 TableAttribute.StartTableVector = (
1101 TableAttribute.TableAttributeStartTableVector
1102 )
1103 TableAttribute.End = TableAttribute.TableAttributeEnd
1104 from tosa import TileAttribute
1105
1106 if not hasattr(TileAttribute, "Start"):
1107 TileAttribute.Start = TileAttribute.TileAttributeStart
1108 TileAttribute.AddMultiples = TileAttribute.TileAttributeAddMultiples
1109 TileAttribute.StartMultiplesVector = (
1110 TileAttribute.TileAttributeStartMultiplesVector
1111 )
1112 TileAttribute.End = TileAttribute.TileAttributeEnd
1113 from tosa import TosaBasicBlock
1114
1115 if not hasattr(TosaBasicBlock, "Start"):
1116 TosaBasicBlock.Start = TosaBasicBlock.TosaBasicBlockStart
1117 TosaBasicBlock.AddName = TosaBasicBlock.TosaBasicBlockAddName
1118 TosaBasicBlock.AddOperators = TosaBasicBlock.TosaBasicBlockAddOperators
1119 TosaBasicBlock.StartOperatorsVector = (
1120 TosaBasicBlock.TosaBasicBlockStartOperatorsVector
1121 )
1122 TosaBasicBlock.AddTensors = TosaBasicBlock.TosaBasicBlockAddTensors
1123 TosaBasicBlock.StartTensorsVector = (
1124 TosaBasicBlock.TosaBasicBlockStartTensorsVector
1125 )
1126 TosaBasicBlock.AddInputs = TosaBasicBlock.TosaBasicBlockAddInputs
1127 TosaBasicBlock.StartInputsVector = (
1128 TosaBasicBlock.TosaBasicBlockStartInputsVector
1129 )
1130 TosaBasicBlock.AddOutputs = TosaBasicBlock.TosaBasicBlockAddOutputs
1131 TosaBasicBlock.StartOutputsVector = (
1132 TosaBasicBlock.TosaBasicBlockStartOutputsVector
1133 )
1134 TosaBasicBlock.End = TosaBasicBlock.TosaBasicBlockEnd
1135 from tosa import TosaGraph
1136
1137 if not hasattr(TosaGraph, "Start"):
1138 TosaGraph.Start = TosaGraph.TosaGraphStart
1139 TosaGraph.AddVersion = TosaGraph.TosaGraphAddVersion
Jerry Ge1eb85042023-01-06 14:19:14 -08001140 TosaGraph.AddRegions = TosaGraph.TosaGraphAddRegions
1141 TosaGraph.StartRegionsVector = TosaGraph.TosaGraphStartRegionsVector
Eric Kunzeae906de2022-05-30 22:40:47 -07001142 TosaGraph.End = TosaGraph.TosaGraphEnd
1143 from tosa import TosaOperator
1144
1145 if not hasattr(TosaOperator, "Start"):
1146 TosaOperator.Start = TosaOperator.TosaOperatorStart
1147 TosaOperator.AddOp = TosaOperator.TosaOperatorAddOp
1148 TosaOperator.AddAttributeType = TosaOperator.TosaOperatorAddAttributeType
1149 TosaOperator.AddAttribute = TosaOperator.TosaOperatorAddAttribute
1150 TosaOperator.AddInputs = TosaOperator.TosaOperatorAddInputs
1151 TosaOperator.StartInputsVector = TosaOperator.TosaOperatorStartInputsVector
1152 TosaOperator.AddOutputs = TosaOperator.TosaOperatorAddOutputs
1153 TosaOperator.StartOutputsVector = (
1154 TosaOperator.TosaOperatorStartOutputsVector
1155 )
Eric Kunzeae906de2022-05-30 22:40:47 -07001156 TosaOperator.End = TosaOperator.TosaOperatorEnd
1157 from tosa import TosaTensor
1158
1159 if not hasattr(TosaTensor, "Start"):
1160 TosaTensor.Start = TosaTensor.TosaTensorStart
1161 TosaTensor.AddName = TosaTensor.TosaTensorAddName
1162 TosaTensor.AddShape = TosaTensor.TosaTensorAddShape
1163 TosaTensor.StartShapeVector = TosaTensor.TosaTensorStartShapeVector
1164 TosaTensor.AddType = TosaTensor.TosaTensorAddType
1165 TosaTensor.AddData = TosaTensor.TosaTensorAddData
1166 TosaTensor.StartDataVector = TosaTensor.TosaTensorStartDataVector
1167 TosaTensor.End = TosaTensor.TosaTensorEnd
1168 from tosa import TransposeAttribute
1169
1170 if not hasattr(TransposeAttribute, "Start"):
1171 TransposeAttribute.Start = TransposeAttribute.TransposeAttributeStart
1172 TransposeAttribute.AddPerms = TransposeAttribute.TransposeAttributeAddPerms
1173 TransposeAttribute.StartPermsVector = (
1174 TransposeAttribute.TransposeAttributeStartPermsVector
1175 )
1176 TransposeAttribute.End = TransposeAttribute.TransposeAttributeEnd
1177 from tosa import TransposeConvAttribute
1178
1179 if not hasattr(TransposeConvAttribute, "Start"):
1180 TransposeConvAttribute.Start = (
1181 TransposeConvAttribute.TransposeConvAttributeStart
1182 )
Eric Kunze4c3537d2022-06-13 17:21:48 -07001183 TransposeConvAttribute.AddOutPad = (
1184 TransposeConvAttribute.TransposeConvAttributeAddOutPad
Eric Kunzeae906de2022-05-30 22:40:47 -07001185 )
Eric Kunze4c3537d2022-06-13 17:21:48 -07001186 TransposeConvAttribute.StartOutPadVector = (
1187 TransposeConvAttribute.TransposeConvAttributeStartOutPadVector
Eric Kunzeae906de2022-05-30 22:40:47 -07001188 )
1189 TransposeConvAttribute.AddStride = (
1190 TransposeConvAttribute.TransposeConvAttributeAddStride
1191 )
1192 TransposeConvAttribute.StartStrideVector = (
1193 TransposeConvAttribute.TransposeConvAttributeStartStrideVector
1194 )
Eric Kunzeae906de2022-05-30 22:40:47 -07001195 TransposeConvAttribute.AddOutputShape = (
1196 TransposeConvAttribute.TransposeConvAttributeAddOutputShape
1197 )
1198 TransposeConvAttribute.StartOutputShapeVector = (
1199 TransposeConvAttribute.TransposeConvAttributeStartOutputShapeVector
1200 )
Eric Kunzebdcc3fe2022-06-07 05:17:37 +00001201 TransposeConvAttribute.AddInputZp = (
1202 TransposeConvAttribute.TransposeConvAttributeAddInputZp
1203 )
1204 TransposeConvAttribute.AddWeightZp = (
1205 TransposeConvAttribute.TransposeConvAttributeAddWeightZp
1206 )
James Ward485a11d2022-08-05 13:48:37 +01001207 TransposeConvAttribute.AddAccumDtype = (
1208 TransposeConvAttribute.TransposeConvAttributeAddAccumDtype
1209 )
Eric Kunzeae906de2022-05-30 22:40:47 -07001210 TransposeConvAttribute.End = (
1211 TransposeConvAttribute.TransposeConvAttributeEnd
1212 )
Eric Kunzeae906de2022-05-30 22:40:47 -07001213 from tosa import Version
1214
1215 if not hasattr(Version, "Start"):
1216 Version.Start = Version.VersionStart
1217 Version.Add_major = Version.VersionAdd_major
1218 Version.Add_minor = Version.VersionAdd_minor
1219 Version.Add_patch = Version.VersionAdd_patch
1220 Version.Add_draft = Version.VersionAdd_draft
1221 Version.End = Version.VersionEnd
Eric Kunzebdcc3fe2022-06-07 05:17:37 +00001222 from tosa import MatMulAttribute
1223
1224 if not hasattr(MatMulAttribute, "Start"):
1225 MatMulAttribute.Start = MatMulAttribute.MatMulAttributeStart
1226 MatMulAttribute.AddAZp = MatMulAttribute.MatMulAttributeAddAZp
1227 MatMulAttribute.AddBZp = MatMulAttribute.MatMulAttributeAddBZp
1228 MatMulAttribute.End = MatMulAttribute.MatMulAttributeEnd
1229 from tosa import FullyConnectedAttribute
1230
1231 if not hasattr(FullyConnectedAttribute, "Start"):
1232 FullyConnectedAttribute.Start = (
1233 FullyConnectedAttribute.FullyConnectedAttributeStart
1234 )
1235 FullyConnectedAttribute.AddInputZp = (
1236 FullyConnectedAttribute.FullyConnectedAttributeAddInputZp
1237 )
1238 FullyConnectedAttribute.AddWeightZp = (
1239 FullyConnectedAttribute.FullyConnectedAttributeAddWeightZp
1240 )
1241 FullyConnectedAttribute.End = (
1242 FullyConnectedAttribute.FullyConnectedAttributeEnd
1243 )
1244 from tosa import NegateAttribute
1245
1246 if not hasattr(NegateAttribute, "Start"):
1247 NegateAttribute.Start = NegateAttribute.NegateAttributeStart
1248 NegateAttribute.AddInput1Zp = NegateAttribute.NegateAttributeAddInput1Zp
1249 NegateAttribute.AddOutputZp = NegateAttribute.NegateAttributeAddOutputZp
1250 NegateAttribute.End = NegateAttribute.NegateAttributeEnd
Eric Kunzeae906de2022-05-30 22:40:47 -07001251 from tosa import WhileLoopAttribute
1252
1253 if not hasattr(WhileLoopAttribute, "Start"):
1254 WhileLoopAttribute.Start = WhileLoopAttribute.WhileLoopAttributeStart
1255 WhileLoopAttribute.AddCondBranch = (
1256 WhileLoopAttribute.WhileLoopAttributeAddCondBranch
1257 )
1258 WhileLoopAttribute.AddBodyBranch = (
1259 WhileLoopAttribute.WhileLoopAttributeAddBodyBranch
1260 )
1261 WhileLoopAttribute.End = WhileLoopAttribute.WhileLoopAttributeEnd