blob: e7bcef1b6f3864aa5e76bf84825f45627d3f1d27 [file] [log] [blame]
Jeremy Johnson9b225172021-12-14 16:34:47 +00001# Copyright (c) 2020-2022, ARM Limited.
Kevin Chengfea5a372021-10-11 18:38:47 +00002#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
Kevin Chengfea5a372021-10-11 18:38:47 +000015import os
Kevin Chengfea5a372021-10-11 18:38:47 +000016import json
17import flatbuffers
18import numpy as np
19import struct
Jeremy Johnson9b225172021-12-14 16:34:47 +000020from enum import IntEnum, unique
Kevin Chengfea5a372021-10-11 18:38:47 +000021from tosa import (
22 TosaGraph,
23 TosaBasicBlock,
24 TosaTensor,
25 TosaOperator,
Kevin Chengfea5a372021-10-11 18:38:47 +000026 Version,
27)
Jeremy Johnson9b225172021-12-14 16:34:47 +000028import tosa.DType as TosaDType
29import tosa.Op as TosaOp
Kevin Chengfea5a372021-10-11 18:38:47 +000030
# The version below must match the default version values declared in
# schema/tosa.fbs; update both together.
TOSA_VERSION_MAJOR = 0
TOSA_VERSION_MINOR = 50
TOSA_VERSION_PATCH = 0
TOSA_VERSION_DRAFT = False

# Positional layout: [major, minor, patch, draft]. TosaSerializer.serialize
# reads the entries by index when it writes the Version table.
TOSA_VERSION = [
    TOSA_VERSION_MAJOR,
    TOSA_VERSION_MINOR,
    TOSA_VERSION_PATCH,
    TOSA_VERSION_DRAFT,
]

# Four-byte flatbuffers file identifier ("TOSA" in ASCII, i.e. 0x54 0x4F
# 0x53 0x41); must match the file_identifier in schema/tosa.fbs.
TOSA_GRAPH_IDENTIFIER = b"TOSA"
45
# flatc's generated Python code exposes DType only as integer constants, so
# there is no programmatic way to turn a value back into its name. This table
# is maintained by hand: entry i is the name of DType value i (that is how
# dtype_str_to_val and TosaSerializerTensor.__str__ index it).
DType = TosaDType.DType()
DTypeNames = [
    "UNKNOWN",
    "BOOL",
    "UINT8",
    "INT4",
    "INT8",
    "INT16",
    "INT32",
    "INT48",
    "FP32",
    "UINT16",
    "FP16",
    "BF16",
]

# Mask that keeps only the lowest byte of an integer.
ByteMask = np.uint64(0xFF)
66
67
def dtype_str_to_val(name):
    """Return the integer DType value whose name matches ``name``.

    The comparison is case-insensitive (``str.casefold``) against the entries
    of ``DTypeNames``; the returned value is the matching entry's index.

    Raises:
        ValueError: if ``name`` matches no known DType name. ValueError is a
            subclass of Exception, so callers catching Exception still work.
    """
    wanted = name.casefold()
    for value, dtype_name in enumerate(DTypeNames):
        if dtype_name.casefold() == wanted:
            return value
    raise ValueError("Unable to parse DType name {}".format(name))
74
75
class TosaSerializerUnion:
    """Accumulates the fields of a flatbuffers union table and serializes them.

    Subclasses fill in ``utype``, ``optFcns`` and the per-kind field lists;
    ``serialize`` then emits the table into a flatbuffers builder.
    """

    def __init__(self):
        # (start, end) functions of the generated table module, assigned by
        # the option constructors in subclasses.
        self.optFcns = None

        # Union type tag, assigned by the option constructors in subclasses.
        self.utype = None

        # Field lists, one per value kind. Every entry is an
        # (add_function, value) pair.
        self.ints = []
        self.bools = []
        self.floats = []
        self.strings = []
        self.int16vecs = []
        self.intvecs = []
        self.fpvecs = []

    def serialize(self, builder):
        """Write the table into ``builder`` and return the resulting offset.

        Strings and vectors cannot be created while a table is open, so all of
        them are built first; only their offsets are added between start/end.
        """
        # Step 1: out-of-line objects (order kept stable: strings, int32
        # vectors, int16 vectors, float vectors).
        string_offsets = [
            (add, builder.CreateString(text)) for add, text in self.strings
        ]
        vector_offsets = [
            (add, TosaSerializer.serializeInt32Vec(builder, vec))
            for add, vec in self.intvecs
        ]
        vector_offsets.extend(
            (add, TosaSerializer.serializeInt16Vec(builder, vec))
            for add, vec in self.int16vecs
        )
        float_vector_offsets = [
            (add, TosaSerializer.serializeFpVec(builder, vec))
            for add, vec in self.fpvecs
        ]

        start, end = self.optFcns

        # Step 2: the table itself - scalars first, then the offsets.
        start(builder)
        for group in (
            self.ints,
            self.bools,
            self.floats,
            string_offsets,
            vector_offsets,
            float_vector_offsets,
        ):
            for add, value in group:
                add(builder, value)
        return end(builder)
142
143
class TosaSerializerAttribute(TosaSerializerUnion):
    """Builds one member of the TOSA ``Attribute`` union.

    Each ``*Attribute`` method imports its generated table module, sets
    ``utype`` and ``optFcns`` (assignment) and appends the field values to the
    per-kind lists (accumulation). Calling more than one such method on the
    same instance would therefore overwrite the type while mixing fields, so
    use one method per instance. The relative order of the appends is kept as
    is because it determines the order of the serialized output.
    """

    def __init__(self):
        super().__init__()

    def PoolAttribute(
        self,
        kernel,
        stride,
        pad,
        input_zp,
        output_zp,
        accum_dtype,
    ):
        """Pooling: pad/kernel/stride vectors, zero points, accumulator dtype."""
        from tosa import PoolAttribute as a, Attribute

        self.utype = Attribute.Attribute().PoolAttribute

        self.optFcns = (a.Start, a.End)
        self.intvecs.append((a.AddPad, pad))
        self.intvecs.append((a.AddKernel, kernel))
        self.intvecs.append((a.AddStride, stride))
        self.ints.append((a.AddInputZp, input_zp))
        self.ints.append((a.AddOutputZp, output_zp))
        self.ints.append((a.AddAccumDtype, accum_dtype))

    def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp, accum_dtype):
        """Convolution: pad/stride/dilation vectors, zero points, accum dtype."""
        from tosa import ConvAttribute as a, Attribute

        self.utype = Attribute.Attribute().ConvAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddPad, pad))
        self.intvecs.append((a.AddStride, stride))
        self.intvecs.append((a.AddDilation, dilation))
        self.ints.append((a.AddInputZp, input_zp))
        self.ints.append((a.AddWeightZp, weight_zp))
        self.ints.append((a.AddAccumDtype, accum_dtype))

    def TransposeConvAttribute(
        self, outpad, stride, output_shape, input_zp, weight_zp, accum_dtype
    ):
        """Transposed convolution: out_pad/stride/output_shape, zero points."""
        from tosa import TransposeConvAttribute as a, Attribute

        self.utype = Attribute.Attribute().TransposeConvAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddOutPad, outpad))
        self.intvecs.append((a.AddStride, stride))
        self.intvecs.append((a.AddOutputShape, output_shape))
        self.ints.append((a.AddInputZp, input_zp))
        self.ints.append((a.AddWeightZp, weight_zp))
        self.ints.append((a.AddAccumDtype, accum_dtype))

    def PadAttribute(self, padding, pad_const_int, pad_const_fp):
        """Pad: padding vector plus integer and floating-point pad constants."""
        from tosa import PadAttribute as a, Attribute

        self.utype = Attribute.Attribute().PadAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddPadding, padding))
        self.ints.append((a.AddPadConstInt, pad_const_int))
        self.floats.append((a.AddPadConstFp, pad_const_fp))

    def AxisAttribute(self, axis):
        """Single ``axis`` integer."""
        from tosa import AxisAttribute as a, Attribute

        self.utype = Attribute.Attribute().AxisAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddAxis, axis))

    def ReshapeAttribute(self, new_shape):
        """Reshape: target shape vector."""
        from tosa import ReshapeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ReshapeAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddNewShape, new_shape))

    def SliceAttribute(self, start, size):
        """Slice: start and size vectors."""
        from tosa import SliceAttribute as a, Attribute

        self.utype = Attribute.Attribute().SliceAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddStart, start))
        self.intvecs.append((a.AddSize, size))

    def TileAttribute(self, multiples):
        """Tile: per-dimension multiples vector."""
        from tosa import TileAttribute as a, Attribute

        self.utype = Attribute.Attribute().TileAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddMultiples, multiples))

    def ResizeAttribute(self, scale, offset, border, mode):
        """Resize: int16 scale/offset/border vectors and an integer mode."""
        from tosa import ResizeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ResizeAttribute
        self.optFcns = (a.Start, a.End)

        self.int16vecs.append((a.AddScale, scale))
        self.int16vecs.append((a.AddOffset, offset))
        self.int16vecs.append((a.AddBorder, border))
        self.ints.append((a.AddMode, mode))

    def ClampAttribute(self, minint, maxint, minfp, maxfp):
        """Clamp: integer min/max bounds and floating-point min/max bounds."""
        from tosa import ClampAttribute as a, Attribute

        self.utype = Attribute.Attribute().ClampAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddMinInt, minint))
        self.ints.append((a.AddMaxInt, maxint))

        # The *_fp bounds are floating-point values (presumably float fields
        # in the schema, like PadAttribute's pad_const_fp), so record them with
        # the floats for consistency. serialize() emits ints before floats and
        # this attribute has no bools, so the Add* call order is unchanged.
        self.floats.append((a.AddMinFp, minfp))
        self.floats.append((a.AddMaxFp, maxfp))

    def RescaleAttribute(
        self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel
    ):
        """Rescale: zero points, multiplier/shift vectors and mode flags."""
        from tosa import RescaleAttribute as a, Attribute

        self.utype = Attribute.Attribute().RescaleAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddInputZp, input_zp))
        self.ints.append((a.AddOutputZp, output_zp))
        self.intvecs.append((a.AddMultiplier, multiplier))
        self.intvecs.append((a.AddShift, shift))
        self.bools.append((a.AddScale32, scale32))
        self.bools.append((a.AddDoubleRound, double_round))
        self.bools.append((a.AddPerChannel, per_channel))

    def MulAttribute(self, shift):
        """Mul: integer result shift."""
        from tosa import MulAttribute as a, Attribute

        self.utype = Attribute.Attribute().MulAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddShift, shift))

    def ArithmeticRightShiftAttribute(self, round):
        """Arithmetic right shift: ``round`` flag.

        The parameter name shadows the builtin but is kept for keyword callers.
        """
        from tosa import ArithmeticRightShiftAttribute as a, Attribute

        self.utype = Attribute.Attribute().ArithmeticRightShiftAttribute
        self.optFcns = (a.Start, a.End)

        self.bools.append((a.AddRound, round))

    def CondIfAttribute(self, then_branch, else_branch):
        """Conditional: names of the then/else basic blocks."""
        from tosa import CondIfAttribute as a, Attribute

        self.utype = Attribute.Attribute().CondIfAttribute
        self.optFcns = (a.Start, a.End)

        self.strings.append((a.AddThenBranch, then_branch))
        self.strings.append((a.AddElseBranch, else_branch))

    def WhileLoopAttribute(self, cond_branch, body_branch):
        """While loop: names of the condition and body basic blocks."""
        from tosa import WhileLoopAttribute as a, Attribute

        self.utype = Attribute.Attribute().WhileLoopAttribute
        self.optFcns = (a.Start, a.End)

        self.strings.append((a.AddCondBranch, cond_branch))
        self.strings.append((a.AddBodyBranch, body_branch))

    def TransposeAttribute(self, perms):
        """Transpose: permutation vector."""
        from tosa import TransposeAttribute as a, Attribute

        self.utype = Attribute.Attribute().TransposeAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddPerms, perms))

    def TableAttribute(self, table):
        """Table: lookup-table values vector."""
        from tosa import TableAttribute as a, Attribute

        self.utype = Attribute.Attribute().TableAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddTable, table))

    def MatMulAttribute(self, A_zp, B_zp, accum_dtype):
        """MatMul: A/B zero points and accumulator dtype."""
        from tosa import MatMulAttribute as a, Attribute

        self.utype = Attribute.Attribute().MatMulAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddAZp, A_zp))
        self.ints.append((a.AddBZp, B_zp))
        self.ints.append((a.AddAccumDtype, accum_dtype))

    def FullyConnectedAttribute(self, input_zp, weight_zp, accum_dtype):
        """Fully connected: input/weight zero points and accumulator dtype."""
        from tosa import FullyConnectedAttribute as a, Attribute

        self.utype = Attribute.Attribute().FullyConnectedAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddInputZp, input_zp))
        self.ints.append((a.AddWeightZp, weight_zp))
        self.ints.append((a.AddAccumDtype, accum_dtype))

    def NegateAttribute(self, input1_zp, output_zp):
        """Negate: input and output zero points."""
        from tosa import NegateAttribute as a, Attribute

        self.utype = Attribute.Attribute().NegateAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddInput1Zp, input1_zp))
        self.ints.append((a.AddOutputZp, output_zp))
Kevin Chengfea5a372021-10-11 18:38:47 +0000362
363
class TosaSerializerTensor:
    """A named tensor (shape, dtype and optional constant data) in a block."""

    def __init__(
        self,
        name,
        shape,
        dtype,
        data=None,
        placeholderFilename=None,
    ):
        self.name = name

        # Normalise the shape to a plain list of Python ints.
        if isinstance(shape, np.ndarray):
            shape = shape.astype(int).tolist()
        shape = list(map(int, shape))

        self.shape = shape
        self.dtype = dtype

        # Element type used to normalise ``data``. BF16 values are held as
        # float32 here (and serialized as 4-byte floats below).
        if dtype == DType.FP32 or dtype == DType.BF16:
            fntype = np.float32
        elif dtype == DType.FP16:
            fntype = np.float16
        else:
            fntype = int

        # Arrays are flattened to a list first; anything that is neither an
        # array nor a list is treated as "no data".
        if isinstance(data, np.ndarray):
            data = data.flatten().astype(fntype).tolist()
        if isinstance(data, list):
            self.data = list(map(fntype, data))
        else:
            self.data = None

        # Filename for placeholder tensors. These get generated by the test
        # generation process and are written to disk, but are considered input
        # tensors by the network so they do not appear in the TOSA
        # serialization. However, if we want to form a unit test around these
        # input tensors, we can get the filename from here.
        self.placeholderFilename = placeholderFilename

    def __str__(self):
        return "TosaSerializerTensor name: {} shape: {} dtype: {}".format(
            self.name,
            self.shape,
            DTypeNames[self.dtype],
        )

    def setDtype(self, dtype):
        self.dtype = dtype

    @staticmethod
    def _pack_le(values, nbytes):
        """Return ``values`` as little-endian bytes, ``nbytes`` per value.

        Each value is reduced to its low ``nbytes`` bytes (two's complement
        for negatives). Masking Python ints avoids ``np.uintN(negative)``,
        which raises OverflowError on NumPy 2. Returns a list of ints 0..255.
        """
        mask = (1 << (8 * nbytes)) - 1
        return list(
            b"".join((int(val) & mask).to_bytes(nbytes, "little") for val in values)
        )

    @staticmethod
    def _pack_int4(values):
        """Pack 4-bit values two per byte, the first of each pair in the low
        nibble; an odd trailing value gets a zero high nibble."""
        packed = []
        for i in range(0, len(values), 2):
            low = values[i]
            high = values[i + 1] if i + 1 < len(values) else 0
            packed.append((low & 0xF) | ((high & 0xF) << 4))
        return packed

    def _data_to_u8(self):
        """Encode ``self.data`` for ``self.dtype`` as a list of byte values.

        Raises:
            Exception: if the dtype has no serialization defined here.
        """
        if self.dtype == DType.BOOL or self.dtype == DType.INT8:
            return self._pack_le(self.data, 1)
        if self.dtype == DType.INT4:
            return self._pack_int4(self.data)
        if self.dtype == DType.INT16:
            return self._pack_le(self.data, 2)
        if self.dtype == DType.INT32:
            return self._pack_le(self.data, 4)
        if self.dtype == DType.INT48:
            return self._pack_le(self.data, 6)
        if self.dtype == DType.FP16:
            # Explicit little-endian dtype so the bytes don't depend on the
            # host byte order.
            return list(np.array(self.data, dtype="<f2").view(np.uint8))
        if self.dtype == DType.FP32 or self.dtype == DType.BF16:
            # BF16 data is emitted as 32-bit little-endian floats.
            return list(b"".join(struct.pack("<f", val) for val in self.data))
        if self.dtype == TosaDType.DType:
            # NOTE(review): this compares against the generated DType class,
            # not a DType value, so it looks unreachable for integer dtypes.
            # Each value is emitted once as 4 little-endian bytes (the old
            # code re-serialized the whole list once per element).
            return self._pack_le(self.data, 4)
        raise Exception("unsupported data type {}".format(DTypeNames[self.dtype]))

    def serialize(self, builder):
        """Serialize this tensor into ``builder`` and return the table offset."""
        fb_name = builder.CreateString(self.name)
        fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape)
        if self.data:
            # The data vector must be created before the table is started.
            fb_data = TosaSerializer.serializeUint8Vec(builder, self._data_to_u8())

        TosaTensor.Start(builder)
        TosaTensor.AddName(builder, fb_name)
        TosaTensor.AddShape(builder, fb_shapes)
        TosaTensor.AddType(builder, self.dtype)
        if self.data:
            TosaTensor.AddData(builder, fb_data)

        return TosaTensor.End(builder)
Kevin Chengfea5a372021-10-11 18:38:47 +0000493
494
class TosaSerializerOperator:
    """One operator in a basic block: opcode, tensor names and attributes."""

    def __init__(self, op, inputs, outputs, attributes=None):
        self.op = op
        self.attributes = attributes
        # Normalised through TosaSerializer.toList; presumably this wraps a
        # single name in a list (addConst passes a bare name as outputs).
        self.inputs = TosaSerializer.toList(inputs)
        self.outputs = TosaSerializer.toList(outputs)

    def __str__(self):
        pieces = ["Op {}\n----\n".format(self.op)]
        pieces.extend(" Input: {}\n".format(name) for name in self.inputs)
        pieces.extend(" Output: {}\n".format(name) for name in self.outputs)
        return "".join(pieces)

    def serialize(self, builder):
        """Serialize the operator into ``builder`` and return its offset."""
        fb_inputs = TosaSerializer.serializeStrVec(
            builder, self.inputs, TosaOperator.StartInputsVector
        )
        fb_outputs = TosaSerializer.serializeStrVec(
            builder, self.outputs, TosaOperator.StartOutputsVector
        )

        has_attributes = self.attributes is not None
        # The attribute table has to be finished before this table is started.
        fb_attributes = self.attributes.serialize(builder) if has_attributes else None

        TosaOperator.Start(builder)
        TosaOperator.AddOp(builder, self.op)
        TosaOperator.AddInputs(builder, fb_inputs)
        TosaOperator.AddOutputs(builder, fb_outputs)
        if has_attributes:
            TosaOperator.AddAttributeType(builder, self.attributes.utype)
            TosaOperator.AddAttribute(builder, fb_attributes)
        return TosaOperator.End(builder)
Kevin Chengfea5a372021-10-11 18:38:47 +0000532
533
class TosaSerializerBasicBlock:
    """A named basic block: operators, tensors, and input/output tensor names."""

    def __init__(self, name):
        self.name = name
        self.operators = []
        # Keyed by tensor name: keeps names unique while allowing lookup.
        self.tensors = dict()
        self.inputs = []
        self.outputs = []

    def addTensor(
        self,
        name,
        shape,
        dtype,
        data=None,
        placeholderFilename=None,
    ):
        """Return the tensor called ``name``, creating it on first request.

        When the name already exists the stored tensor is returned unchanged
        and the remaining arguments are ignored.
        """
        if name in self.tensors:
            return self.tensors[name]

        created = TosaSerializerTensor(name, shape, dtype, data, placeholderFilename)
        self.tensors[name] = created
        return created

    def addInput(self, name):
        """Append ``name`` to the block's input list."""
        self.inputs.append(name)

    def addOutput(self, name):
        """Append ``name`` to the block's output list."""
        self.outputs.append(name)

    def addOperator(self, op, inputs, outputs, attributes=None):
        """Append a new operator to the block."""
        new_op = TosaSerializerOperator(op, inputs, outputs, attributes)
        self.operators.append(new_op)

    def serialize(self, builder):
        """Serialize the block into ``builder`` and return its offset."""
        # Everything out-of-line is created before the table is started.
        fb_name = builder.CreateString(self.name)
        fbv_inputs = TosaSerializer.serializeStrVec(
            builder, list(self.inputs), TosaBasicBlock.StartInputsVector
        )
        fbv_outputs = TosaSerializer.serializeStrVec(
            builder, list(self.outputs), TosaBasicBlock.StartOutputsVector
        )
        fbv_tensors = TosaSerializer.serializeObjVec(
            builder, list(self.tensors.values()), TosaBasicBlock.StartTensorsVector
        )
        fbv_operators = TosaSerializer.serializeObjVec(
            builder, self.operators, TosaBasicBlock.StartOperatorsVector
        )

        TosaBasicBlock.Start(builder)
        TosaBasicBlock.AddName(builder, fb_name)
        TosaBasicBlock.AddInputs(builder, fbv_inputs)
        TosaBasicBlock.AddOutputs(builder, fbv_outputs)
        TosaBasicBlock.AddTensors(builder, fbv_tensors)
        TosaBasicBlock.AddOperators(builder, fbv_operators)
        return TosaBasicBlock.End(builder)
Kevin Chengfea5a372021-10-11 18:38:47 +0000593
594
@unique
class TensorDir(IntEnum):
    """Role of a tensor in a graph.

    The member names mirror TosaSerializer's addPlaceholder / addConst /
    addIntermediate / addOutput helpers; the enum itself is not referenced in
    the visible part of this file.
    """

    PLACEHOLDER = 0
    CONST = 1
    INTERMEDIATE = 2
    RESULT = 3
601
602
603class TosaSerializer:
Jeremy Johnsonc92710d2022-09-15 12:16:07 +0100604 def __init__(self, pathPrefix, saveConstsToFile=False):
Eric Kunzeae906de2022-05-30 22:40:47 -0700605 self.add_compat_methods()
Kevin Chengfea5a372021-10-11 18:38:47 +0000606 # Get the global TOSA version if not already defined
Kevin Chengfea5a372021-10-11 18:38:47 +0000607
608 self.builder = flatbuffers.Builder(0)
609
610 self.basicBlocks = []
611 self.startBasicBlock("main")
612 self.pathPrefix = pathPrefix
613
Jeremy Johnsonc92710d2022-09-15 12:16:07 +0100614 # Enables inspection of constant data outside of graph
615 self.saveConstsToFile = saveConstsToFile
616
Kevin Chengfea5a372021-10-11 18:38:47 +0000617 # Indicies used for adding/naming tensors
618 self.currInputIdx = 0
619 self.currConstIdx = 0
620 self.currLayerIdx = 1
621 self.currResultIdx = 0
622
623 # Is this an illegal test that is expected to fail?
Jeremy Johnson9b225172021-12-14 16:34:47 +0000624 self.expectedReturnCode = 0
Kevin Chengfea5a372021-10-11 18:38:47 +0000625 self.expectedFailure = False
626 self.expectedFailureDesc = ""
627
628 def __str__(self):
629 str = ""
630 for bb in self.basicBlocks:
631 str = str + bb.__str__()
632 return str
633
634 def addPlaceholder(self, shape, dtype, vals):
635 if not self.currBasicBlock:
636 raise Exception("addTensor called without valid basic block")
637
638 name = "input-{}".format(self.currInputIdx)
639 filename = "{}.npy".format(name)
640 self.currInputIdx = self.currInputIdx + 1
641
642 tens = self.currBasicBlock.addTensor(name, shape, dtype, None, filename)
643 # This is always an input to the block
644 self.currBasicBlock.addInput(name)
645
646 if vals is not None:
647 np.save(os.path.join(self.pathPrefix, filename), vals, False)
648
649 return tens
650
651 def addConst(self, shape, dtype, vals):
652 if not self.currBasicBlock:
653 raise Exception("addTensor called without valid basic block")
654
655 name = "const-{}".format(self.currInputIdx)
Kevin Chengfea5a372021-10-11 18:38:47 +0000656 self.currInputIdx = self.currInputIdx + 1
657
658 tens = self.currBasicBlock.addTensor(name, shape, dtype, vals)
659 # Add the operator now
Jeremy Johnson9b225172021-12-14 16:34:47 +0000660 self.currBasicBlock.addOperator(TosaOp.Op().CONST, [], name)
Kevin Chengfea5a372021-10-11 18:38:47 +0000661
Jeremy Johnsonc92710d2022-09-15 12:16:07 +0100662 if self.saveConstsToFile:
663 filename = "{}.npy".format(name)
664 np.save(os.path.join(self.pathPrefix, filename), vals, False)
665
Kevin Chengfea5a372021-10-11 18:38:47 +0000666 return tens
667
668 def addIntermediate(self, shape, dtype):
669
670 if not self.currBasicBlock:
671 raise Exception("addTensor called without valid basic block")
672
673 name = "layer-{}".format(self.currLayerIdx)
674 self.currLayerIdx = self.currLayerIdx + 1
675
676 tens = self.currBasicBlock.addTensor(name, shape, dtype, None)
677
678 return tens
679
680 def addInputTensor(self, tensor):
681 self.currBasicBlock.addTensor(tensor.name, tensor.shape, tensor.dtype)
682 self.currBasicBlock.addInput(tensor.name)
683
684 def addOutputTensor(self, tensor):
685 self.currBasicBlock.addOutput(tensor.name)
686
687 def addOutput(self, shape, dtype):
688 if not self.currBasicBlock:
689 raise Exception("addTensor called without valid basic block")
690
691 name = "result-{}".format(self.currResultIdx)
692 self.currResultIdx = self.currResultIdx + 1
693
694 tens = self.currBasicBlock.addTensor(name, shape, dtype, None)
695 self.currBasicBlock.addOutput(name)
696 return tens
697
Eric Kunzebdcc3fe2022-06-07 05:17:37 +0000698 def addOperator(self, op, inputs, outputs, attributes=None):
Kevin Chengfea5a372021-10-11 18:38:47 +0000699
Jeremy Johnson9b225172021-12-14 16:34:47 +0000700 if op == TosaOp.Op().CONST:
Kevin Chengfea5a372021-10-11 18:38:47 +0000701 raise Exception("Use addConstTensor() to add CONST ops")
702
703 return self.currBasicBlock.addOperator(
Eric Kunzebdcc3fe2022-06-07 05:17:37 +0000704 op,
705 inputs,
706 outputs,
707 attributes,
Kevin Chengfea5a372021-10-11 18:38:47 +0000708 )
709
Jeremy Johnson9b225172021-12-14 16:34:47 +0000710 def setExpectedReturnCode(self, val, fail, desc=""):
Kevin Chengfea5a372021-10-11 18:38:47 +0000711
712 self.expectedReturnCode = val
713 self.expectedFailureDesc = desc
Jeremy Johnson9b225172021-12-14 16:34:47 +0000714 self.expectedFailure = fail
Kevin Chengfea5a372021-10-11 18:38:47 +0000715
def serialize(self):
    """Finish the TosaGraph flatbuffer and return the serialized buffer.

    Build order:
      1. the Version table, filled from TOSA_VERSION (major, minor, patch, draft);
      2. the vector of basic blocks;
      3. the root TosaGraph table referencing both.
    The buffer is then finished with TOSA_GRAPH_IDENTIFIER as its file
    identifier. Child objects are written before the table that refers to
    them, so the statement order below must be kept.

    Returns:
        The value of ``builder.Output()`` for the finished buffer.
    """
    fbb = self.builder

    # Version table.
    Version.Start(fbb)
    Version.Add_major(fbb, TOSA_VERSION[0])
    Version.Add_minor(fbb, TOSA_VERSION[1])
    Version.Add_patch(fbb, TOSA_VERSION[2])
    Version.Add_draft(fbb, TOSA_VERSION[3])
    version_offset = Version.End(fbb)

    # Vector of serialized basic blocks.
    blocks_offset = TosaSerializer.serializeObjVec(
        fbb, self.basicBlocks, TosaGraph.StartBlocksVector
    )

    # Root table tying version and blocks together.
    TosaGraph.Start(fbb)
    TosaGraph.AddVersion(fbb, version_offset)
    TosaGraph.AddBlocks(fbb, blocks_offset)
    root_offset = TosaGraph.End(fbb)

    fbb.Finish(root_offset, TOSA_GRAPH_IDENTIFIER)
    return fbb.Output()
738
def writeJson(self, tosa_filename):
    """Build a JSON test descriptor for the serialized graph and return it.

    The descriptor makes it easy to pick up the test and generate commands
    for a third-party tool. Only basic blocks named ``"main"`` contribute
    input/output entries. Nothing is written to disk; the JSON text
    (indented with a single space) is returned.

    Args:
        tosa_filename: value stored under ``"tosa_file"``.

    Returns:
        str: the serialized descriptor.
    """
    main_blocks = [blk for blk in self.basicBlocks if blk.name == "main"]

    input_pairs = [
        (tensor_name, blk.tensors[tensor_name].placeholderFilename)
        for blk in main_blocks
        for tensor_name in blk.inputs
    ]
    output_names = [
        tensor_name for blk in main_blocks for tensor_name in blk.outputs
    ]

    descriptor = {
        "tosa_file": tosa_filename,
        "ifm_name": [pair[0] for pair in input_pairs],
        "ifm_file": [pair[1] for pair in input_pairs],
        "ofm_name": output_names,
        # No OFM file exists until the reference tool is run, so any
        # placeholder name is acceptable here.
        "ofm_file": ["ref-{}.npy".format(out) for out in output_names],
        "expected_return_code": self.expectedReturnCode,
        "expected_failure": self.expectedFailure,
    }
    if self.expectedFailureDesc:
        descriptor["expected_failure_desc"] = self.expectedFailureDesc

    return json.dumps(descriptor, indent=" ")
771
def startBasicBlock(self, name):
    """Create a basic block called *name*, make it current and register it.

    Methods such as addOutput() and addOperator() operate on
    ``self.currBasicBlock``, i.e. on the block most recently started here.
    """
    block = TosaSerializerBasicBlock(name)
    self.currBasicBlock = block
    self.basicBlocks.append(block)
775
@staticmethod
def serializeStrVec(builder, vec, start_fcn):
    """Serialize a sequence of strings as a flatbuffers vector of offsets.

    All strings are created first (flatbuffers objects cannot be nested in
    an open vector). Then *start_fcn* (a generated ``Start<Field>Vector``
    function) opens the vector and the offsets are prepended back to front.

    Returns:
        The vector offset from ``builder.EndVector``.
    """
    offsets = list(map(builder.CreateString, vec))
    start_fcn(builder, len(offsets))
    for offset in reversed(offsets):
        builder.PrependUOffsetTRelative(offset)
    try:
        return builder.EndVector()
    except TypeError:
        # Older flatbuffers builders require the element count explicitly.
        return builder.EndVector(len(vec))
Kevin Chengfea5a372021-10-11 18:38:47 +0000786
@staticmethod
def serializeUint8Vec(builder, vec):
    """Serialize *vec* as a flatbuffers vector of uint8 with 8-byte alignment.

    Elements are prepended last-to-first so the finished vector keeps the
    order of *vec*. Returns the vector offset from ``builder.EndVector``.
    """
    count = len(vec)
    builder.StartVector(1, count, 8)
    for item in reversed(vec):
        builder.PrependUint8(item)
    try:
        return builder.EndVector()
    except TypeError:
        # Older flatbuffers builders require the element count explicitly.
        return builder.EndVector(count)
796
@staticmethod
def serializeInt16Vec(builder, vec):
    """Serialize *vec* as a flatbuffers vector of int16 with 4-byte alignment.

    Elements are prepended last-to-first so the finished vector keeps the
    order of *vec*. Returns the vector offset from ``builder.EndVector``.
    """
    count = len(vec)
    builder.StartVector(2, count, 4)
    for item in reversed(vec):
        builder.PrependInt16(item)
    try:
        return builder.EndVector()
    except TypeError:
        # Older flatbuffers builders require the element count explicitly.
        return builder.EndVector(count)
806
@staticmethod
def serializeInt32Vec(builder, vec):
    """Serialize *vec* as a flatbuffers vector of int32 with 4-byte alignment.

    Elements are prepended last-to-first so the finished vector keeps the
    order of *vec*. Returns the vector offset from ``builder.EndVector``.
    """
    count = len(vec)
    builder.StartVector(4, count, 4)
    for item in reversed(vec):
        builder.PrependInt32(item)
    try:
        return builder.EndVector()
    except TypeError:
        # Older flatbuffers builders require the element count explicitly.
        return builder.EndVector(count)
816
@staticmethod
def serializeFpVec(builder, vec):
    """Serialize *vec* as a flatbuffers vector of float32 with 4-byte alignment.

    Elements are prepended last-to-first so the finished vector keeps the
    order of *vec*. Returns the vector offset from ``builder.EndVector``.
    """
    count = len(vec)
    builder.StartVector(4, count, 4)
    for item in reversed(vec):
        builder.PrependFloat32(item)
    try:
        return builder.EndVector()
    except TypeError:
        # Older flatbuffers builders require the element count explicitly.
        return builder.EndVector(count)
826
@staticmethod
def serializeObjVec(builder, vec, start_fcn):
    """Serialize objects exposing ``serialize(builder)`` into a vector of offsets.

    Each object is serialized first, last-to-first. The vector is then
    opened with ``start_fcn(builder, len(vec))`` (a generated
    ``Start<Field>Vector`` function), and the offsets are prepended in that
    same last-to-first order, so the finished vector follows the order of
    *vec*.

    Returns:
        The vector offset from ``builder.EndVector``.
    """
    offsets = [obj.serialize(builder) for obj in reversed(vec)]

    start_fcn(builder, len(vec))
    for offset in offsets:
        builder.PrependUOffsetTRelative(offset)
    try:
        return builder.EndVector()
    except TypeError:
        # Older flatbuffers builders require the element count explicitly.
        return builder.EndVector(len(vec))
840
@staticmethod
def toList(val):
    """Return *val* itself if it is a list, otherwise a one-element list.

    Only real ``list`` instances pass through; tuples and other sequences
    are wrapped like any other value.
    """
    return val if isinstance(val, list) else [val]
Eric Kunzeae906de2022-05-30 22:40:47 -0700847
# Remove when switching to flatbuffers 2.0
# contains a mapping of the deprecated 1.12 method to the 2.0 version


def add_compat_methods(self):
    """Give the generated ``tosa`` modules the short builder function names.

    The serializer calls short names such as ``TosaGraph.Start`` or
    ``Version.Add_major``. Per the note above, older (flatbuffers 1.12)
    generated modules only provide the prefixed spellings
    (``TosaGraph.TosaGraphStart``). For every module listed in
    ``compat_aliases`` that has no ``Start`` attribute, each short name ``N``
    is bound to ``<ModuleName>N`` on that module. Modules that already have
    ``Start`` are left untouched. ``self`` is not used.

    Modules are processed in table order and ``Start`` is always aliased
    first. A missing module (ImportError) or a missing prefixed function
    (AttributeError) therefore propagates with all earlier aliases already
    installed.
    """
    import importlib

    # (module name, short names to alias) -- each module listed exactly once.
    compat_aliases = (
        ("ArithmeticRightShiftAttribute", ("Start", "AddRound", "End")),
        ("AxisAttribute", ("Start", "AddAxis", "End")),
        (
            "ClampAttribute",
            ("Start", "AddMinInt", "AddMaxInt", "AddMinFp", "AddMaxFp", "End"),
        ),
        ("CondIfAttribute", ("Start", "AddThenBranch", "AddElseBranch", "End")),
        (
            "ConvAttribute",
            (
                "Start",
                "AddPad",
                "StartPadVector",
                "AddStride",
                "StartStrideVector",
                "AddDilation",
                "StartDilationVector",
                "AddInputZp",
                "AddWeightZp",
                "AddAccumDtype",
                "End",
            ),
        ),
        (
            "FullyConnectedAttribute",
            ("Start", "AddInputZp", "AddWeightZp", "AddAccumDtype", "End"),
        ),
        (
            "MatMulAttribute",
            ("Start", "AddAZp", "AddBZp", "AddAccumDtype", "End"),
        ),
        (
            "PoolAttribute",
            (
                "Start",
                "AddPad",
                "StartPadVector",
                "AddKernel",
                "StartKernelVector",
                "AddStride",
                "StartStrideVector",
                "AddAccumDtype",
                "AddInputZp",
                "AddOutputZp",
                "End",
            ),
        ),
        ("MulAttribute", ("Start", "AddShift", "End")),
        (
            "PadAttribute",
            (
                "Start",
                "AddPadding",
                "StartPaddingVector",
                "AddPadConstInt",
                "AddPadConstFp",
                "End",
            ),
        ),
        (
            "RescaleAttribute",
            (
                "Start",
                "AddInputZp",
                "AddOutputZp",
                "AddMultiplier",
                "StartMultiplierVector",
                "AddShift",
                "StartShiftVector",
                "AddScale32",
                "AddDoubleRound",
                "AddPerChannel",
                "End",
            ),
        ),
        (
            "ReshapeAttribute",
            ("Start", "AddNewShape", "StartNewShapeVector", "End"),
        ),
        (
            "ResizeAttribute",
            (
                "Start",
                "AddScale",
                "StartScaleVector",
                "AddOffset",
                "StartOffsetVector",
                "AddBorder",
                "StartBorderVector",
                "AddMode",
                "End",
            ),
        ),
        (
            "SliceAttribute",
            (
                "Start",
                "AddStart",
                "StartStartVector",
                "AddSize",
                "StartSizeVector",
                "End",
            ),
        ),
        ("TableAttribute", ("Start", "AddTable", "StartTableVector", "End")),
        (
            "TileAttribute",
            ("Start", "AddMultiples", "StartMultiplesVector", "End"),
        ),
        (
            "TosaBasicBlock",
            (
                "Start",
                "AddName",
                "AddOperators",
                "StartOperatorsVector",
                "AddTensors",
                "StartTensorsVector",
                "AddInputs",
                "StartInputsVector",
                "AddOutputs",
                "StartOutputsVector",
                "End",
            ),
        ),
        (
            "TosaGraph",
            ("Start", "AddVersion", "AddBlocks", "StartBlocksVector", "End"),
        ),
        (
            "TosaOperator",
            (
                "Start",
                "AddOp",
                "AddAttributeType",
                "AddAttribute",
                "AddInputs",
                "StartInputsVector",
                "AddOutputs",
                "StartOutputsVector",
                "End",
            ),
        ),
        (
            "TosaTensor",
            (
                "Start",
                "AddName",
                "AddShape",
                "StartShapeVector",
                "AddType",
                "AddData",
                "StartDataVector",
                "End",
            ),
        ),
        (
            "TransposeAttribute",
            ("Start", "AddPerms", "StartPermsVector", "End"),
        ),
        (
            "TransposeConvAttribute",
            (
                "Start",
                "AddOutPad",
                "StartOutPadVector",
                "AddStride",
                "StartStrideVector",
                "AddOutputShape",
                "StartOutputShapeVector",
                "AddInputZp",
                "AddWeightZp",
                "AddAccumDtype",
                "End",
            ),
        ),
        (
            "Version",
            ("Start", "Add_major", "Add_minor", "Add_patch", "Add_draft", "End"),
        ),
        (
            "NegateAttribute",
            ("Start", "AddInput1Zp", "AddOutputZp", "End"),
        ),
        (
            "WhileLoopAttribute",
            ("Start", "AddCondBranch", "AddBodyBranch", "End"),
        ),
    )

    for module_name, short_names in compat_aliases:
        module = importlib.import_module("tosa." + module_name)
        # "Start" doubles as the marker: when present the module already
        # exposes the short names and must not be modified.
        if hasattr(module, "Start"):
            continue
        for short_name in short_names:
            setattr(module, short_name, getattr(module, module_name + short_name))