blob: 5ec45d136a5f2340869948fd878948efad4a05b9 [file] [log] [blame]
Jerry Ge1eb85042023-01-06 14:19:14 -08001# Copyright (c) 2020-2023, ARM Limited.
Kevin Chengfea5a372021-10-11 18:38:47 +00002#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
Kevin Chengfea5a372021-10-11 18:38:47 +000015import os
Kevin Chengfea5a372021-10-11 18:38:47 +000016import json
17import flatbuffers
18import numpy as np
19import struct
Jeremy Johnson9b225172021-12-14 16:34:47 +000020from enum import IntEnum, unique
Kevin Chengfea5a372021-10-11 18:38:47 +000021from tosa import (
22 TosaGraph,
Jerry Ge1eb85042023-01-06 14:19:14 -080023 TosaRegion,
Kevin Chengfea5a372021-10-11 18:38:47 +000024 TosaBasicBlock,
25 TosaTensor,
26 TosaOperator,
Kevin Chengfea5a372021-10-11 18:38:47 +000027 Version,
28)
Jeremy Johnson9b225172021-12-14 16:34:47 +000029import tosa.DType as TosaDType
30import tosa.Op as TosaOp
Kevin Chengfea5a372021-10-11 18:38:47 +000031
# Version of the TOSA specification that graphs produced by this module carry.
# Keep version number in sync with the version default value with schema/tosa.fbs
TOSA_VERSION_MAJOR = 0
TOSA_VERSION_MINOR = 51
TOSA_VERSION_PATCH = 0
# Draft flag of the version above.
TOSA_VERSION_DRAFT = True
# The four version components above, collected in [major, minor, patch, draft] order.
TOSA_VERSION = [
    TOSA_VERSION_MAJOR,
    TOSA_VERSION_MINOR,
    TOSA_VERSION_PATCH,
    TOSA_VERSION_DRAFT,
]

# File identifier needs to be kept in sync with schema/tosa.fbs
# (the four bytes are ASCII "TOSA").
TOSA_GRAPH_IDENTIFIER = b"\x54\x4F\x53\x41"
46
# With the way flatc generates its python types, there is no programmatic way
# to get string names for the integer types. Manually maintain a string table
# here.
# DType is an instance of the flatc-generated DType class; its attributes
# (DType.BOOL, DType.FP32, ...) are the integer values compared against
# tensor dtypes throughout this file.
DType = TosaDType.DType()
# DTypeNames[i] is the display name of DType value i: it is indexed by dtype
# value (see TosaSerializerTensor.__str__) and dtype_str_to_val returns the
# list index as the value. NOTE(review): the order therefore presumably has
# to match the DType enum in schema/tosa.fbs exactly -- keep them in sync.
DTypeNames = [
    "UNKNOWN",
    "BOOL",
    "UINT8",
    "INT4",
    "INT8",
    "INT16",
    "INT32",
    "INT48",
    "FP32",
    "UINT16",
    "FP16",
    "BF16",
]

# Mask selecting the low byte; used when splitting integer values into bytes.
ByteMask = np.uint64(0xFF)
67
68
def dtype_str_to_val(name):
    """Map a DType name to its integer value, ignoring case.

    The value is the position of the name in DTypeNames.

    Raises:
        Exception: if ``name`` matches none of the known DType names.
    """
    wanted = name.casefold()
    for value, candidate in enumerate(DTypeNames):
        if candidate.casefold() == wanted:
            return value
    raise Exception("Unable to parse DType name {}".format(name))
75
76
class TosaSerializerUnion:
    """Accumulates the fields of one flatbuffers union table and writes them out.

    The option constructors of subclasses set ``optFcns`` and ``utype`` and
    append ``(add_function, value)`` pairs to the per-kind lists below;
    ``serialize`` then builds the table from those pairs.
    """

    def __init__(self):
        # (start_function, end_function) of the table being built.
        # Set by the option constructors of subclasses.
        self.optFcns = None

        # Union type tag (from the generated tosa enumeration).
        # Set by the option constructors of subclasses.
        self.utype = None

        # Lists of (add_function, value) pairs, one list per value kind.
        self.ints = []
        self.bools = []
        self.floats = []
        self.strings = []
        self.int16vecs = []
        self.intvecs = []
        self.fpvecs = []

    def serialize(self, builder):
        """Write the table into ``builder`` and return the offset from its end function."""
        # Strings and vectors have to be built before the table is started.
        string_refs = [(add, builder.CreateString(text)) for add, text in self.strings]

        int_vec_refs = [
            (add, TosaSerializer.serializeInt32Vec(builder, vec))
            for add, vec in self.intvecs
        ]
        int_vec_refs += [
            (add, TosaSerializer.serializeInt16Vec(builder, vec))
            for add, vec in self.int16vecs
        ]

        fp_vec_refs = [
            (add, TosaSerializer.serializeFpVec(builder, vec))
            for add, vec in self.fpvecs
        ]

        begin, finish = self.optFcns

        # Scalars first, then the offsets created above, in a fixed order.
        begin(builder)
        field_groups = (
            self.ints,
            self.bools,
            self.floats,
            string_refs,
            int_vec_refs,
            fp_vec_refs,
        )
        for group in field_groups:
            for add, value in group:
                add(builder, value)

        return finish(builder)
143
144
class TosaSerializerAttribute(TosaSerializerUnion):
    """Builds one operator attribute table.

    Call exactly one of the ``*Attribute`` methods to choose the attribute
    kind. It sets ``utype`` and ``optFcns`` and queues each field in the
    list matching its value kind (ints, bools, floats, strings, intvecs,
    int16vecs). The inherited ``serialize`` then emits the table.
    """

    def __init__(self):
        super().__init__()

    def PoolAttribute(
        self,
        kernel,
        stride,
        pad,
        input_zp,
        output_zp,
        accum_dtype,
    ):
        """Configure a PoolAttribute (pad/kernel/stride vectors, zero points, accumulator dtype)."""
        from tosa import PoolAttribute as a, Attribute

        self.utype = Attribute.Attribute().PoolAttribute

        self.optFcns = (a.Start, a.End)
        self.intvecs.append((a.AddPad, pad))
        self.intvecs.append((a.AddKernel, kernel))
        self.intvecs.append((a.AddStride, stride))
        self.ints.append((a.AddInputZp, input_zp))
        self.ints.append((a.AddOutputZp, output_zp))
        self.ints.append((a.AddAccumDtype, accum_dtype))

    def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp, accum_dtype):
        """Configure a ConvAttribute (pad/stride/dilation vectors, zero points, accumulator dtype)."""
        from tosa import ConvAttribute as a, Attribute

        self.utype = Attribute.Attribute().ConvAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddPad, pad))
        self.intvecs.append((a.AddStride, stride))
        self.intvecs.append((a.AddDilation, dilation))
        self.ints.append((a.AddInputZp, input_zp))
        self.ints.append((a.AddWeightZp, weight_zp))
        self.ints.append((a.AddAccumDtype, accum_dtype))

    def TransposeConvAttribute(
        self, outpad, stride, output_shape, input_zp, weight_zp, accum_dtype
    ):
        """Configure a TransposeConvAttribute (out_pad/stride/output_shape vectors, zero points, accumulator dtype)."""
        from tosa import TransposeConvAttribute as a, Attribute

        self.utype = Attribute.Attribute().TransposeConvAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddOutPad, outpad))
        self.intvecs.append((a.AddStride, stride))
        self.intvecs.append((a.AddOutputShape, output_shape))
        self.ints.append((a.AddInputZp, input_zp))
        self.ints.append((a.AddWeightZp, weight_zp))
        self.ints.append((a.AddAccumDtype, accum_dtype))

    def PadAttribute(self, padding, pad_const_int, pad_const_fp):
        """Configure a PadAttribute (padding vector plus integer and float pad constants)."""
        from tosa import PadAttribute as a, Attribute

        self.utype = Attribute.Attribute().PadAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddPadding, padding))
        self.ints.append((a.AddPadConstInt, pad_const_int))
        self.floats.append((a.AddPadConstFp, pad_const_fp))

    def AxisAttribute(self, axis):
        """Configure an AxisAttribute."""
        from tosa import AxisAttribute as a, Attribute

        self.utype = Attribute.Attribute().AxisAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddAxis, axis))

    def ReshapeAttribute(self, new_shape):
        """Configure a ReshapeAttribute."""
        from tosa import ReshapeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ReshapeAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddNewShape, new_shape))

    def SliceAttribute(self, start, size):
        """Configure a SliceAttribute (start and size vectors)."""
        from tosa import SliceAttribute as a, Attribute

        self.utype = Attribute.Attribute().SliceAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddStart, start))
        self.intvecs.append((a.AddSize, size))

    def TileAttribute(self, multiples):
        """Configure a TileAttribute."""
        from tosa import TileAttribute as a, Attribute

        self.utype = Attribute.Attribute().TileAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddMultiples, multiples))

    def ResizeAttribute(self, scale, offset, border, mode):
        """Configure a ResizeAttribute (scale/offset/border as int16 vectors, plus mode)."""
        from tosa import ResizeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ResizeAttribute
        self.optFcns = (a.Start, a.End)

        self.int16vecs.append((a.AddScale, scale))
        self.int16vecs.append((a.AddOffset, offset))
        self.int16vecs.append((a.AddBorder, border))
        self.ints.append((a.AddMode, mode))

    def ClampAttribute(self, minint, maxint, minfp, maxfp):
        """Configure a ClampAttribute (integer and floating-point bounds)."""
        from tosa import ClampAttribute as a, Attribute

        self.utype = Attribute.Attribute().ClampAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddMinInt, minint))
        self.ints.append((a.AddMaxInt, maxint))

        # The float bounds are queued as floats, as PadAttribute does for its
        # float constant. serialize() emits ints before floats, so the
        # sequence of Add* calls is the same as before.
        self.floats.append((a.AddMinFp, minfp))
        self.floats.append((a.AddMaxFp, maxfp))

    def RescaleAttribute(
        self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel
    ):
        """Configure a RescaleAttribute (zero points, multiplier/shift vectors, flags)."""
        from tosa import RescaleAttribute as a, Attribute

        self.utype = Attribute.Attribute().RescaleAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddInputZp, input_zp))
        self.ints.append((a.AddOutputZp, output_zp))
        self.intvecs.append((a.AddMultiplier, multiplier))
        self.intvecs.append((a.AddShift, shift))
        self.bools.append((a.AddScale32, scale32))
        self.bools.append((a.AddDoubleRound, double_round))
        self.bools.append((a.AddPerChannel, per_channel))

    def MulAttribute(self, shift):
        """Configure a MulAttribute."""
        from tosa import MulAttribute as a, Attribute

        self.utype = Attribute.Attribute().MulAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddShift, shift))

    def ArithmeticRightShiftAttribute(self, round):
        """Configure an ArithmeticRightShiftAttribute."""
        from tosa import ArithmeticRightShiftAttribute as a, Attribute

        self.utype = Attribute.Attribute().ArithmeticRightShiftAttribute
        self.optFcns = (
            a.Start,
            a.End,
        )

        self.bools.append((a.AddRound, round))

    def CondIfAttribute(self, then_branch, else_branch):
        """Configure a CondIfAttribute naming the then/else branches."""
        from tosa import CondIfAttribute as a, Attribute

        self.utype = Attribute.Attribute().CondIfAttribute
        self.optFcns = (a.Start, a.End)

        self.strings.append((a.AddThenBranch, then_branch))
        self.strings.append((a.AddElseBranch, else_branch))

    def WhileLoopAttribute(self, cond_branch, body_branch):
        """Configure a WhileLoopAttribute naming the cond/body branches."""
        from tosa import WhileLoopAttribute as a, Attribute

        self.utype = Attribute.Attribute().WhileLoopAttribute
        self.optFcns = (a.Start, a.End)

        self.strings.append((a.AddCondBranch, cond_branch))
        self.strings.append((a.AddBodyBranch, body_branch))

    def TransposeAttribute(self, perms):
        """Configure a TransposeAttribute."""
        from tosa import TransposeAttribute as a, Attribute

        self.utype = Attribute.Attribute().TransposeAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddPerms, perms))

    def TableAttribute(self, table):
        """Configure a TableAttribute."""
        from tosa import TableAttribute as a, Attribute

        self.utype = Attribute.Attribute().TableAttribute
        self.optFcns = (a.Start, a.End)

        self.intvecs.append((a.AddTable, table))

    def MatMulAttribute(self, A_zp, B_zp, accum_dtype):
        """Configure a MatMulAttribute (A/B zero points, accumulator dtype)."""
        from tosa import MatMulAttribute as a, Attribute

        self.utype = Attribute.Attribute().MatMulAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddAZp, A_zp))
        self.ints.append((a.AddBZp, B_zp))
        self.ints.append((a.AddAccumDtype, accum_dtype))

    def FullyConnectedAttribute(self, input_zp, weight_zp, accum_dtype):
        """Configure a FullyConnectedAttribute (zero points, accumulator dtype)."""
        from tosa import FullyConnectedAttribute as a, Attribute

        self.utype = Attribute.Attribute().FullyConnectedAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddInputZp, input_zp))
        self.ints.append((a.AddWeightZp, weight_zp))
        self.ints.append((a.AddAccumDtype, accum_dtype))

    def NegateAttribute(self, input1_zp, output_zp):
        """Configure a NegateAttribute (input and output zero points)."""
        from tosa import NegateAttribute as a, Attribute

        self.utype = Attribute.Attribute().NegateAttribute
        self.optFcns = (a.Start, a.End)

        self.ints.append((a.AddInput1Zp, input1_zp))
        self.ints.append((a.AddOutputZp, output_zp))
Kevin Chengfea5a372021-10-11 18:38:47 +0000363
364
class TosaSerializerTensor:
    """A named tensor of a basic block, optionally carrying constant data.

    ``data`` is kept as a flat Python list. Elements are converted to
    np.float32 for FP32/BF16, np.float16 for FP16 and int for every other
    dtype, or ``data`` is None when no list/ndarray was given.
    """

    def __init__(
        self,
        name,
        shape,
        dtype,
        data=None,
        placeholderFilename=None,
    ):
        self.name = name

        if isinstance(shape, np.ndarray):
            shape = shape.astype(int).tolist()
        shape = list(map(int, shape))

        self.shape = shape
        self.dtype = dtype

        # Element type the flattened data is normalised to.
        if dtype == DType.FP32 or dtype == DType.BF16:
            fntype = np.float32
        elif dtype == DType.FP16:
            fntype = np.float16
        else:
            fntype = int

        if isinstance(data, np.ndarray):
            data = data.flatten().astype(fntype).tolist()
            data = list(map(fntype, data))
            self.data = data
        elif isinstance(data, list):
            data = list(map(fntype, data))
            self.data = data
        else:
            self.data = None

        # Filename for placeholder tensors. These get generated by the test generation
        # process and are written to disk, but are considered input tensors by the
        # network so they do not appear in the TOSA serialization. However, if we
        # want to form a unit test around these input tensors, we can get the filename
        # from here.
        self.placeholderFilename = placeholderFilename

    def __str__(self):
        concatString = "TosaSerializerTensor name: {} shape: {} dtype: {}".format(
            self.name,
            self.shape,
            DTypeNames[self.dtype],
        )
        return concatString

    def setDtype(self, dtype):
        """Change the tensor's dtype; existing ``data`` is not converted."""
        self.dtype = dtype

    @staticmethod
    def _packLittleEndian(values, numBytes):
        """Return the little-endian bytes (as ints) of each integer in ``values``.

        Each value is reduced modulo 2**(8 * numBytes) first, so negative
        numbers come out in two's complement. The arithmetic is done on Python
        ints rather than by casting to unsigned numpy scalars, because newer
        NumPy releases raise OverflowError for out-of-range (e.g. negative)
        Python ints in ``np.uint8(-1)``-style conversions.
        """
        mask = (1 << (8 * numBytes)) - 1
        packed = []
        for val in values:
            packed.extend((int(val) & mask).to_bytes(numBytes, "little"))
        return packed

    def serialize(self, builder):
        """Write this tensor into ``builder`` and return the TosaTensor offset.

        Raises:
            Exception: if the tensor has data of an unsupported dtype.
        """
        fb_name = builder.CreateString(self.name)
        fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape)
        if self.data:
            # All element data is laid out little-endian.
            if self.dtype == DType.BOOL or self.dtype == DType.INT8:
                u8_data = self._packLittleEndian(self.data, 1)
            elif self.dtype == DType.INT4:
                # Two 4-bit values per byte, first element in the low nibble;
                # an odd trailing element gets a zero high nibble.
                u8_data = []
                count = len(self.data)
                for i in range(0, count, 2):
                    val_0 = int(self.data[i])
                    val_1 = int(self.data[i + 1]) if i + 1 < count else 0
                    u8_data.append((val_0 & 0xF) | ((val_1 & 0xF) << 4))
            elif self.dtype == DType.INT16:
                u8_data = self._packLittleEndian(self.data, 2)
            elif self.dtype == DType.INT32:
                u8_data = self._packLittleEndian(self.data, 4)
            elif self.dtype == DType.INT48:
                u8_data = self._packLittleEndian(self.data, 6)
            elif self.dtype == DType.FP16:
                # Explicit "<f2" so the byte order does not depend on the host.
                u8_data = list(np.array(self.data, dtype="<f2").view(np.uint8))
            elif self.dtype == DType.FP32 or self.dtype == DType.BF16:
                # BF16 data is written with the same 4-byte float32 layout as FP32.
                u8_data = []
                for val in self.data:
                    u8_data.extend(struct.pack("<f", val))
            elif self.dtype == TosaDType.DType:
                # NOTE(review): self.dtype is compared against the generated
                # DType *class* here, while elsewhere it holds integer values,
                # so this branch looks unreachable -- confirm the intended test.
                # Serialize DType enum data as 4-byte little-endian words
                # (once for the whole list, not once per element).
                u8_data = list(np.array(self.data, dtype="<u4").view(np.uint8))
            else:
                raise Exception(
                    "unsupported data type {}".format(DTypeNames[self.dtype])
                )
            fb_data = TosaSerializer.serializeUint8Vec(builder, u8_data)

        TosaTensor.Start(builder)
        TosaTensor.AddName(builder, fb_name)
        TosaTensor.AddShape(builder, fb_shapes)
        TosaTensor.AddType(builder, self.dtype)
        if self.data:
            TosaTensor.AddData(builder, fb_data)

        return TosaTensor.End(builder)
Kevin Chengfea5a372021-10-11 18:38:47 +0000494
495
class TosaSerializerOperator:
    """One operator: an opcode, input/output tensor names and optional attributes."""

    def __init__(self, op, inputs, outputs, attributes=None):
        self.op = op
        self.attributes = attributes
        self.inputs = TosaSerializer.toList(inputs)
        self.outputs = TosaSerializer.toList(outputs)

    def __str__(self):
        pieces = ["Op {}\n----\n".format(self.op)]
        pieces.extend(" Input: {}\n".format(name) for name in self.inputs)
        pieces.extend(" Output: {}\n".format(name) for name in self.outputs)
        return "".join(pieces)

    def serialize(self, builder):
        """Write this operator into ``builder`` and return the TosaOperator offset."""
        fb_inputs = TosaSerializer.serializeStrVec(
            builder, self.inputs, TosaOperator.StartInputsVector
        )
        fb_outputs = TosaSerializer.serializeStrVec(
            builder, self.outputs, TosaOperator.StartOutputsVector
        )

        attrs = self.attributes
        # The attribute table is finished before the operator table is started.
        fb_attributes = None if attrs is None else attrs.serialize(builder)

        TosaOperator.Start(builder)
        TosaOperator.AddOp(builder, self.op)
        TosaOperator.AddInputs(builder, fb_inputs)
        TosaOperator.AddOutputs(builder, fb_outputs)
        if attrs is not None:
            TosaOperator.AddAttributeType(builder, attrs.utype)
            TosaOperator.AddAttribute(builder, fb_attributes)

        return TosaOperator.End(builder)
Kevin Chengfea5a372021-10-11 18:38:47 +0000533
534
class TosaSerializerBasicBlock:
    """A named basic block: operators, the tensors they use, and block I/O names."""

    def __init__(self, name, pathPrefix, saveConstsToFile=False):
        self.name = name
        self.pathPrefix = pathPrefix
        self.saveConstsToFile = saveConstsToFile
        self.operators = []

        # Keyed by tensor name: keeps names unique and allows lookup by name.
        self.tensors = dict()

        self.inputs = []
        self.outputs = []

    def addTensor(
        self,
        name,
        shape,
        dtype,
        data=None,
        placeholderFilename=None,
    ):
        """Return the tensor called ``name``, creating it first if it does not exist.

        When the name is already present the existing tensor is returned and
        the remaining arguments are ignored.
        """
        if name in self.tensors:
            return self.tensors[name]

        created = TosaSerializerTensor(name, shape, dtype, data, placeholderFilename)
        self.tensors[name] = created
        return created

    def addInput(self, name):
        """Record ``name`` as an input of this block."""
        self.inputs.append(name)

    def addOutput(self, name):
        """Record ``name`` as an output of this block."""
        self.outputs.append(name)

    def addOperator(self, op, inputs, outputs, attributes=None):
        """Append a new operator to this block."""
        new_op = TosaSerializerOperator(op, inputs, outputs, attributes)
        self.operators.append(new_op)

    def serialize(self, builder):
        """Write this block into ``builder`` and return the TosaBasicBlock offset."""
        # Child strings and vectors are created before the block table starts.
        fb_name = builder.CreateString(self.name)
        fbv_inputs = TosaSerializer.serializeStrVec(
            builder, list(self.inputs), TosaBasicBlock.StartInputsVector
        )
        fbv_outputs = TosaSerializer.serializeStrVec(
            builder, list(self.outputs), TosaBasicBlock.StartOutputsVector
        )
        fbv_tensors = TosaSerializer.serializeObjVec(
            builder,
            list(self.tensors.values()),
            TosaBasicBlock.StartTensorsVector,
        )
        fbv_operators = TosaSerializer.serializeObjVec(
            builder, self.operators, TosaBasicBlock.StartOperatorsVector
        )

        TosaBasicBlock.Start(builder)
        fields = (
            (TosaBasicBlock.AddName, fb_name),
            (TosaBasicBlock.AddInputs, fbv_inputs),
            (TosaBasicBlock.AddOutputs, fbv_outputs),
            (TosaBasicBlock.AddTensors, fbv_tensors),
            (TosaBasicBlock.AddOperators, fbv_operators),
        )
        for add_field, value in fields:
            add_field(builder, value)
        return TosaBasicBlock.End(builder)
Kevin Chengfea5a372021-10-11 18:38:47 +0000596
597
class TosaSerializerRegion:
    """A named region made of basic blocks.

    Tensors and operators are added to ``currBasicBlock``, which is the block
    most recently added by ``addBasicBlock``. Auto-generated tensor names use
    per-region counters.
    """

    def __init__(self, name, pathPrefix, saveConstsToFile=False):
        self.name = name
        self.basicBlocks = []
        # Target of the add* helpers; None until addBasicBlock() is called,
        # so the "without valid basic block" checks below can fire instead of
        # an AttributeError on a missing attribute.
        self.currBasicBlock = None
        self.currInputIdx = 0
        self.currConstIdx = 0
        self.currLayerIdx = 1
        self.currResultIdx = 0
        self.pathPrefix = pathPrefix
        self.saveConstsToFile = saveConstsToFile

    def __str__(self):
        """Return the concatenated text of every operator of every basic block."""
        return "".join(
            str(op) for block in self.basicBlocks for op in block.operators
        )

    def _requireBasicBlock(self):
        """Return the current basic block, raising if none has been added yet."""
        if not self.currBasicBlock:
            raise Exception("addTensor called without valid basic block")
        return self.currBasicBlock

    def addBasicBlock(self, name, pathPrefix, saveConstsToFile):
        """Create a basic block, append it, and make it the current block."""
        self.currBasicBlock = TosaSerializerBasicBlock(
            name, pathPrefix, saveConstsToFile
        )
        self.basicBlocks.append(self.currBasicBlock)

    def serialize(self, builder):
        """Write this region into ``builder`` and return the TosaRegion offset."""
        fb_name = builder.CreateString(self.name)
        fbv_basicBlocks = TosaSerializer.serializeObjVec(
            builder, self.basicBlocks, TosaRegion.StartBlocksVector
        )

        TosaRegion.Start(builder)
        TosaRegion.AddName(builder, fb_name)
        TosaRegion.AddBlocks(builder, fbv_basicBlocks)
        return TosaRegion.End(builder)

    def addPlaceholder(self, shape, dtype, vals):
        """Add an "input-N" tensor as a block input; save ``vals`` to <pathPrefix>/input-N.npy if given."""
        block = self._requireBasicBlock()

        name = "input-{}".format(self.currInputIdx)
        filename = "{}.npy".format(name)
        self.currInputIdx = self.currInputIdx + 1

        tens = block.addTensor(name, shape, dtype, None, filename)
        # This is always an input to the block
        block.addInput(name)

        if vals is not None:
            np.save(os.path.join(self.pathPrefix, filename), vals, False)

        return tens

    def addConst(self, shape, dtype, vals):
        """Add a "const-N" tensor holding ``vals`` together with its CONST operator."""
        block = self._requireBasicBlock()

        # NOTE(review): const names share currInputIdx with placeholders
        # (currConstIdx is never used); kept so generated names stay stable.
        name = "const-{}".format(self.currInputIdx)
        self.currInputIdx = self.currInputIdx + 1

        tens = block.addTensor(name, shape, dtype, vals)
        # Add the operator now
        block.addOperator(TosaOp.Op().CONST, [], name)

        if self.saveConstsToFile:
            filename = "{}.npy".format(name)
            np.save(os.path.join(self.pathPrefix, filename), vals, False)

        return tens

    def addIntermediate(self, shape, dtype):
        """Add a data-less "layer-N" tensor to the current block."""
        block = self._requireBasicBlock()

        name = "layer-{}".format(self.currLayerIdx)
        self.currLayerIdx = self.currLayerIdx + 1

        tens = block.addTensor(name, shape, dtype, None)

        return tens

    def addInputTensor(self, tensor):
        """Register an existing tensor's name/shape/dtype as an input of the current block."""
        block = self._requireBasicBlock()
        block.addTensor(tensor.name, tensor.shape, tensor.dtype)
        block.addInput(tensor.name)

    def addOutputTensor(self, tensor):
        """Mark an existing tensor's name as an output of the current block."""
        block = self._requireBasicBlock()
        block.addOutput(tensor.name)

    def addOutput(self, shape, dtype):
        """Add a "result-N" tensor and mark it as a block output."""
        block = self._requireBasicBlock()

        name = "result-{}".format(self.currResultIdx)
        self.currResultIdx = self.currResultIdx + 1

        tens = block.addTensor(name, shape, dtype, None)
        block.addOutput(name)
        return tens

    def addOperator(self, op, inputs, outputs, attributes=None):
        """Add a non-CONST operator to the current block.

        Raises:
            Exception: for CONST ops (use addConst()) or when no block exists.
        """
        if op == TosaOp.Op().CONST:
            raise Exception("Use addConst() to add CONST ops")

        block = self._requireBasicBlock()
        return block.addOperator(
            op,
            inputs,
            outputs,
            attributes,
        )
699
Jerry Ge1eb85042023-01-06 14:19:14 -0800700
@unique
class TensorDir(IntEnum):
    """Integer-valued kind of a tensor within a graph.

    The member names mirror the tensor-creating helpers in this module
    (addPlaceholder / addConst / addIntermediate / addOutput); how the
    enum itself is consumed is not visible here.
    """

    PLACEHOLDER = 0  # graph input tensor
    CONST = 1  # constant tensor
    INTERMEDIATE = 2  # tensor produced and consumed inside the graph
    RESULT = 3  # graph output tensor
707
708
709class TosaSerializer:
def __init__(self, pathPrefix, saveConstsToFile=False):
    """Create a serializer holding one "main" region with a "main" block.

    Args:
        pathPrefix: path prefix handed to the region and its basic block
            (presumably where constant .npy files are written when
            saveConstsToFile is set -- TODO confirm).
        saveConstsToFile: when True, constant data is also saved outside
            the graph so it can be inspected.
    """
    # Enables inspection of constant data outside of graph
    self.saveConstsToFile = saveConstsToFile

    self.builder = flatbuffers.Builder(0)

    self.regions = []
    self.startRegion("main", pathPrefix, saveConstsToFile)
    self.currRegion.addBasicBlock("main", pathPrefix, saveConstsToFile)

    # Expected outcome: is this an illegal test that is expected to fail?
    self.expectedReturnCode = 0
    self.expectedFailure = False
    self.expectedFailureDesc = ""
725
def __str__(self):
    """Return the string forms of all regions concatenated in order."""
    # str.join builds the result in one pass instead of re-allocating the
    # accumulated string for every region.
    return "".join(str(region) for region in self.regions)
731
def addPlaceholder(self, shape, dtype, vals):
    """Add a placeholder tensor to the current region and return it."""
    region = self.currRegion
    return region.addPlaceholder(shape, dtype, vals)
734
def addConst(self, shape, dtype, vals):
    """Add a constant tensor (and its CONST op) to the current region."""
    region = self.currRegion
    return region.addConst(shape, dtype, vals)
737
def addIntermediate(self, shape, dtype):
    """Add an intermediate tensor to the current region and return it."""
    region = self.currRegion
    return region.addIntermediate(shape, dtype)
740
def addInputTensor(self, tensor):
    """Register *tensor* as an input of the current region."""
    region = self.currRegion
    region.addInputTensor(tensor)
743
def addOutputTensor(self, tensor):
    """Register *tensor* as an output of the current region."""
    region = self.currRegion
    region.addOutputTensor(tensor)
746
def addOutput(self, shape, dtype):
    """Add a result tensor to the current region and return it."""
    region = self.currRegion
    return region.addOutput(shape, dtype)
749
def addOperator(self, op, inputs, outputs, attributes=None):
    """Add an operator to the current region and return the region's result."""
    region = self.currRegion
    return region.addOperator(op, inputs, outputs, attributes)
752
def setExpectedReturnCode(self, val, fail, desc=""):
    """Record the expected outcome of running this test graph.

    Args:
        val: expected return code.
        fail: whether the test is expected to fail.
        desc: optional human-readable reason for the expected failure.
    """
    self.expectedFailure = fail
    self.expectedFailureDesc = desc
    self.expectedReturnCode = val
Kevin Chengfea5a372021-10-11 18:38:47 +0000758
def serialize(self):
    """Serialize the whole graph into a TOSA flatbuffer and return it.

    Build order is significant for flatbuffers:
      1. the Version table from TOSA_VERSION (major, minor, patch, draft),
      2. the vector of serialized regions,
      3. the root TosaGraph table referencing both.
    The buffer is finished with TOSA_GRAPH_IDENTIFIER.

    Returns:
        The finished buffer as returned by ``builder.Output()``.
    """
    builder = self.builder

    version_adders = (
        Version.Add_major,
        Version.Add_minor,
        Version.Add_patch,
        Version.Add_draft,
    )
    Version.Start(builder)
    for add_field, value in zip(version_adders, TOSA_VERSION):
        add_field(builder, value)
    version = Version.End(builder)

    regions_vec = TosaSerializer.serializeObjVec(
        builder, self.regions, TosaGraph.StartRegionsVector
    )

    TosaGraph.Start(builder)
    TosaGraph.AddVersion(builder, version)
    TosaGraph.AddRegions(builder, regions_vec)
    graph = TosaGraph.End(builder)

    builder.Finish(graph, TOSA_GRAPH_IDENTIFIER)
    return builder.Output()
781
def writeJson(self, tosa_filename):
    """Build the JSON test descriptor for this graph and return it as text.

    The descriptor makes it easy for third-party tools to pick up the test
    and generate commands for it. Nothing is written to disk here; the
    caller saves the returned string.

    Args:
        tosa_filename: name of the serialized .tosa file to reference.

    Returns:
        A JSON string (indented with one space) with the input/output
        tensor names and files plus the expected-result fields.
    """
    ifm_name, ifm_file, ofm_name, ofm_file = [], [], [], []

    for region in self.regions:
        # Empty (falsy) basic block slots are skipped.
        for blk in filter(None, region.basicBlocks):
            for tensor_name in blk.inputs:
                ifm_name.append(tensor_name)
                ifm_file.append(blk.tensors[tensor_name].placeholderFilename)
            for tensor_name in blk.outputs:
                ofm_name.append(tensor_name)
                # No OFM file exists until the reference tool runs, so
                # any made-up name is acceptable here.
                ofm_file.append("ref-{}.npy".format(tensor_name))

    test_desc = {
        "tosa_file": tosa_filename,
        "ifm_name": ifm_name,
        "ifm_file": ifm_file,
        "ofm_name": ofm_name,
        "ofm_file": ofm_file,
        "expected_return_code": self.expectedReturnCode,
        "expected_failure": self.expectedFailure,
    }
    if self.expectedFailureDesc:
        test_desc["expected_failure_desc"] = self.expectedFailureDesc

    return json.dumps(test_desc, indent=" ")
815
def startRegion(self, name, pathPrefix, saveConstsToFile):
    """Create a new region, append it to the graph and make it current."""
    region = TosaSerializerRegion(name, pathPrefix, saveConstsToFile)
    self.currRegion = region
    self.regions.append(region)
Kevin Chengfea5a372021-10-11 18:38:47 +0000819
@staticmethod
def serializeStrVec(builder, vec, start_fcn):
    """Serialize a sequence of strings as a flatbuffers vector of strings.

    Args:
        builder: flatbuffers Builder to write into.
        vec: sequence of Python strings.
        start_fcn: generated ``Start<Field>Vector`` function that opens
            the vector on ``builder``.

    Returns:
        The offset of the finished vector.
    """
    # Create every string before opening the vector; flatbuffers does not
    # allow building other objects while a vector is open.
    string_offsets = [builder.CreateString(text) for text in vec]
    start_fcn(builder, len(string_offsets))
    # The builder fills back-to-front, so prepend in reverse order.
    for offset in reversed(string_offsets):
        builder.PrependUOffsetTRelative(offset)
    try:
        return builder.EndVector()
    except TypeError:
        # Presumably an older flatbuffers whose EndVector needs the count.
        return builder.EndVector(len(vec))
Kevin Chengfea5a372021-10-11 18:38:47 +0000830
@staticmethod
def serializeUint8Vec(builder, vec):
    """Serialize a sequence of ints as a flatbuffers uint8 vector.

    Elements are 1 byte each; the vector requests 8-byte alignment
    (presumably so the raw bytes can be reinterpreted as wider element
    types -- TODO confirm).

    Returns:
        The offset of the finished vector.
    """
    count = len(vec)
    builder.StartVector(1, count, 8)
    # The builder fills back-to-front, so prepend in reverse order.
    for byte in reversed(vec):
        builder.PrependUint8(byte)
    try:
        return builder.EndVector()
    except TypeError:
        # Presumably an older flatbuffers whose EndVector needs the count.
        return builder.EndVector(count)
840
@staticmethod
def serializeInt16Vec(builder, vec):
    """Serialize a sequence of ints as a flatbuffers int16 vector.

    Elements are 2 bytes each; the vector requests 4-byte alignment.

    Returns:
        The offset of the finished vector.
    """
    count = len(vec)
    builder.StartVector(2, count, 4)
    # The builder fills back-to-front, so prepend in reverse order.
    for value in reversed(vec):
        builder.PrependInt16(value)
    try:
        return builder.EndVector()
    except TypeError:
        # Presumably an older flatbuffers whose EndVector needs the count.
        return builder.EndVector(count)
850
@staticmethod
def serializeInt32Vec(builder, vec):
    """Serialize a sequence of ints as a flatbuffers int32 vector.

    Elements are 4 bytes each with 4-byte alignment.

    Returns:
        The offset of the finished vector.
    """
    count = len(vec)
    builder.StartVector(4, count, 4)
    # The builder fills back-to-front, so prepend in reverse order.
    for value in reversed(vec):
        builder.PrependInt32(value)
    try:
        return builder.EndVector()
    except TypeError:
        # Presumably an older flatbuffers whose EndVector needs the count.
        return builder.EndVector(count)
860
@staticmethod
def serializeFpVec(builder, vec):
    """Serialize a sequence of numbers as a flatbuffers float32 vector.

    Elements are 4 bytes each with 4-byte alignment.

    Returns:
        The offset of the finished vector.
    """
    count = len(vec)
    builder.StartVector(4, count, 4)
    # The builder fills back-to-front, so prepend in reverse order.
    for value in reversed(vec):
        builder.PrependFloat32(value)
    try:
        return builder.EndVector()
    except TypeError:
        # Presumably an older flatbuffers whose EndVector needs the count.
        return builder.EndVector(count)
870
@staticmethod
def serializeObjVec(builder, vec, start_fcn):
    """Serialize a sequence of objects as a flatbuffers vector of tables.

    Each element must provide ``serialize(builder)`` returning an offset.
    Elements are serialized last-to-first before the vector is opened,
    then their offsets are prepended so the final vector keeps the
    original element order.

    Args:
        builder: flatbuffers Builder to write into.
        vec: sequence of serializable objects.
        start_fcn: generated ``Start<Field>Vector`` function that opens
            the vector on ``builder``.

    Returns:
        The offset of the finished vector.
    """
    offsets = [item.serialize(builder) for item in reversed(vec)]

    start_fcn(builder, len(vec))
    for offset in offsets:
        builder.PrependUOffsetTRelative(offset)
    try:
        return builder.EndVector()
    except TypeError:
        # Presumably an older flatbuffers whose EndVector needs the count.
        return builder.EndVector(len(vec))
884
@staticmethod
def toList(val):
    """Return *val* unchanged if it is a list, otherwise wrap it in one."""
    return val if isinstance(val, list) else [val]