blob: d9c3633346e48f6752bc87439900bfc7879150b9 [file] [log] [blame]
Kevin Cheng3a478572021-01-22 17:21:02 -08001# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07002#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15#!/usr/bin/env python3
16
Kevin Cheng550ccc52021-03-03 11:21:43 -080017import os
18import sys
19import json
Eric Kunzee5e26762020-10-13 16:11:07 -070020import flatbuffers
21import numpy as np
Kevin Cheng82507d72021-06-17 16:01:59 -070022import struct
Eric Kunzee5e26762020-10-13 16:11:07 -070023from enum import Enum, IntEnum, unique
Kevin Cheng550ccc52021-03-03 11:21:43 -080024from tosa import (
25 TosaGraph,
26 TosaBasicBlock,
27 TosaTensor,
28 TosaOperator,
29 DType,
30 Op,
31 ResizeMode,
32 Version,
33)
Kevin Chengacb550f2021-06-29 15:32:19 -070034from tosa_ref_run import TosaReturnCode
Kevin Cheng550ccc52021-03-03 11:21:43 -080035
36# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
37parent_dir = os.path.dirname(os.path.realpath(__file__))
38sys.path.append(
39 os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
40)
Eric Kunzee5e26762020-10-13 16:11:07 -070041import tosa
Eric Kunzee5e26762020-10-13 16:11:07 -070042
43# With the way flatc generates its python types, there is no programatic way
44# to get string names for the integer types. Manually maintain a string table
45# here.
Kevin Cheng82507d72021-06-17 16:01:59 -070046DType = tosa.DType.DType()
Kevin Cheng550ccc52021-03-03 11:21:43 -080047DTypeNames = [
48 "UNKNOWN",
49 "BOOL",
50 "UINT8",
51 "INT4",
52 "INT8",
53 "INT16",
54 "INT32",
55 "INT48",
56 "FLOAT",
57]
58
Kevin Cheng82507d72021-06-17 16:01:59 -070059ByteMask = np.uint64(0xFF)
Eric Kunzee5e26762020-10-13 16:11:07 -070060
Kevin Chengacb550f2021-06-29 15:32:19 -070061
Eric Kunzee5e26762020-10-13 16:11:07 -070062def dtype_str_to_val(name):
63
64 for i in range(len(DTypeNames)):
65 if name.casefold() == DTypeNames[i].casefold():
66 return i
Kevin Cheng550ccc52021-03-03 11:21:43 -080067 raise Exception("Unable to parse DType name {}".format(name))
Eric Kunzee5e26762020-10-13 16:11:07 -070068
69
70class TosaSerializerUnion:
Kevin Cheng550ccc52021-03-03 11:21:43 -080071 """This class handles encapsulating and serializing union types into flatbuffers"""
72
Eric Kunzee5e26762020-10-13 16:11:07 -070073 def __init__(self):
74
75 # A tuple of the start and end functions. Set by the options constructors below
76 self.optFcns = None
77
78 # The type from the tosa.Options enumeration. Set by the options constructors below.
79 self.utype = None
80
81 # Each of these lists is a tuple of the add function and the
82 # value being added. Set by the options constructors below.
83 self.ints = []
84 self.bools = []
85 self.floats = []
86 self.strings = []
87 self.intvecs = []
Kevin Cheng77d0f762020-11-24 10:26:32 -080088 self.fpvecs = []
Eric Kunzee5e26762020-10-13 16:11:07 -070089
90 def serialize(self, builder):
91
92 # We have to build strings and vectors first
93 strList = []
94 intVecList = []
Kevin Cheng77d0f762020-11-24 10:26:32 -080095 fpVecList = []
Eric Kunzee5e26762020-10-13 16:11:07 -070096
97 for fcn, val in self.strings:
98 strList.append((fcn, builder.CreateString(val)))
99
100 for fcn, val in self.intvecs:
101 intVecList.append((fcn, TosaSerializer.serializeInt32Vec(builder, val)))
102
Kevin Cheng77d0f762020-11-24 10:26:32 -0800103 for fcn, val in self.fpvecs:
104 fpVecList.append((fcn, TosaSerializer.serializeFpVec(builder, val)))
105
Eric Kunzee5e26762020-10-13 16:11:07 -0700106 startFcn, endFcn = self.optFcns
107
108 # Then serialize the options object from the list of primitives and
109 # other serialized values
110 startFcn(builder)
111 for fcn, val in self.ints:
112 fcn(builder, val)
113
114 for fcn, val in self.bools:
115 fcn(builder, val)
116
117 for fcn, val in self.floats:
118 fcn(builder, val)
119
120 for fcn, val in strList:
121 fcn(builder, val)
122
123 for fcn, val in intVecList:
124 fcn(builder, val)
125
Kevin Cheng77d0f762020-11-24 10:26:32 -0800126 for fcn, val in fpVecList:
127 fcn(builder, val)
128
Eric Kunzee5e26762020-10-13 16:11:07 -0700129 return endFcn(builder)
130
Kevin Cheng550ccc52021-03-03 11:21:43 -0800131
Eric Kunzee5e26762020-10-13 16:11:07 -0700132class TosaSerializerAttribute(TosaSerializerUnion):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800133 """This class handles encapsulating all of the enumerated types for attributes"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700134
135 def __init__(self):
136 super().__init__()
137
Kevin Cheng93a16282021-08-31 16:14:03 -0700138 def PoolAttribute(self, kernel, stride, padding):
139 from tosa import PoolAttribute as a, Attribute
Eric Kunzee5e26762020-10-13 16:11:07 -0700140
Kevin Cheng93a16282021-08-31 16:14:03 -0700141 self.utype = Attribute.Attribute().PoolAttribute
Eric Kunzee5e26762020-10-13 16:11:07 -0700142
Kevin Cheng93a16282021-08-31 16:14:03 -0700143 self.optFcns = (a.PoolAttributeStart, a.PoolAttributeEnd)
144 self.intvecs.append((a.PoolAttributeAddPadding, padding))
145 self.intvecs.append((a.PoolAttributeAddKernel, kernel))
146 self.intvecs.append((a.PoolAttributeAddStride, stride))
Eric Kunzee5e26762020-10-13 16:11:07 -0700147
Kevin Cheng93a16282021-08-31 16:14:03 -0700148 def ConvAttribute(self, padding, stride, dilation):
149 from tosa import ConvAttribute as a, Attribute
Eric Kunzee5e26762020-10-13 16:11:07 -0700150
Kevin Cheng93a16282021-08-31 16:14:03 -0700151 self.utype = Attribute.Attribute().ConvAttribute
152 self.optFcns = (a.ConvAttributeStart, a.ConvAttributeEnd)
Eric Kunzee5e26762020-10-13 16:11:07 -0700153
Kevin Cheng93a16282021-08-31 16:14:03 -0700154 self.intvecs.append((a.ConvAttributeAddPadding, padding))
155 self.intvecs.append((a.ConvAttributeAddStride, stride))
156 self.intvecs.append((a.ConvAttributeAddDilation, dilation))
Eric Kunzee5e26762020-10-13 16:11:07 -0700157
Kevin Cheng93a16282021-08-31 16:14:03 -0700158 def TransposeConvAttribute(self, outpad, stride, dilation, output_shape):
159 from tosa import TransposeConvAttribute as a, Attribute
Eric Kunzee5e26762020-10-13 16:11:07 -0700160
Kevin Cheng93a16282021-08-31 16:14:03 -0700161 self.utype = Attribute.Attribute().TransposeConvAttribute
162 self.optFcns = (a.TransposeConvAttributeStart, a.TransposeConvAttributeEnd)
Eric Kunzee5e26762020-10-13 16:11:07 -0700163
Kevin Cheng93a16282021-08-31 16:14:03 -0700164 self.intvecs.append((a.TransposeConvAttributeAddOutpad, outpad))
165 self.intvecs.append((a.TransposeConvAttributeAddStride, stride))
166 self.intvecs.append((a.TransposeConvAttributeAddDilation, dilation))
167 self.intvecs.append((a.TransposeConvAttributeAddOutputShape, output_shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700168
169 def ReluNAttribute(self, maxint, maxfp):
170 from tosa import ReluNAttribute as a, Attribute
171
172 self.utype = Attribute.Attribute().ReluNAttribute
173 self.optFcns = (a.ReluNAttributeStart, a.ReluNAttributeEnd)
174
175 self.ints.append((a.ReluNAttributeAddMaxInt, maxint))
176 self.ints.append((a.ReluNAttributeAddMaxFp, maxfp))
177
Eric Kunzee5e26762020-10-13 16:11:07 -0700178 def AxisAttribute(self, axis):
179 from tosa import AxisAttribute as a, Attribute
180
181 self.utype = Attribute.Attribute().AxisAttribute
182 self.optFcns = (a.AxisAttributeStart, a.AxisAttributeEnd)
183
Kevin Cheng550ccc52021-03-03 11:21:43 -0800184 self.ints.append((a.AxisAttributeAddAxis, axis))
Eric Kunzee5e26762020-10-13 16:11:07 -0700185
186 def ReshapeAttribute(self, shape):
187 from tosa import ReshapeAttribute as a, Attribute
188
189 self.utype = Attribute.Attribute().ReshapeAttribute
190 self.optFcns = (a.ReshapeAttributeStart, a.ReshapeAttributeEnd)
191
Kevin Cheng550ccc52021-03-03 11:21:43 -0800192 self.intvecs.append((a.ReshapeAttributeAddShape, shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700193
194 def SliceAttribute(self, begin, size):
195 from tosa import SliceAttribute as a, Attribute
196
197 self.utype = Attribute.Attribute().SliceAttribute
198 self.optFcns = (a.SliceAttributeStart, a.SliceAttributeEnd)
199
Kevin Cheng550ccc52021-03-03 11:21:43 -0800200 self.intvecs.append((a.SliceAttributeAddBegin, begin))
201 self.intvecs.append((a.SliceAttributeAddSize, size))
Eric Kunzee5e26762020-10-13 16:11:07 -0700202
203 def TileAttribute(self, multiples):
204 from tosa import TileAttribute as a, Attribute
205
206 self.utype = Attribute.Attribute().TileAttribute
207 self.optFcns = (a.TileAttributeStart, a.TileAttributeEnd)
208
Kevin Cheng550ccc52021-03-03 11:21:43 -0800209 self.intvecs.append((a.TileAttributeAddMultiples, multiples))
Eric Kunzee5e26762020-10-13 16:11:07 -0700210
Kevin Cheng550ccc52021-03-03 11:21:43 -0800211 def ResizeAttribute(
212 self, output_size, stride, offset, shift, stride_fp, offset_fp, mode
213 ):
Eric Kunzee5e26762020-10-13 16:11:07 -0700214 from tosa import ResizeAttribute as a, Attribute
215
216 self.utype = Attribute.Attribute().ResizeAttribute
217 self.optFcns = (a.ResizeAttributeStart, a.ResizeAttributeEnd)
218
Kevin Cheng550ccc52021-03-03 11:21:43 -0800219 self.intvecs.append((a.ResizeAttributeAddOutputSize, output_size))
220 self.intvecs.append((a.ResizeAttributeAddStride, stride))
221 self.intvecs.append((a.ResizeAttributeAddOffset, offset))
222 self.ints.append((a.ResizeAttributeAddShift, shift))
223 self.fpvecs.append((a.ResizeAttributeAddStrideFp, stride_fp))
224 self.fpvecs.append((a.ResizeAttributeAddOffsetFp, offset_fp))
225 self.ints.append((a.ResizeAttributeAddMode, mode))
Eric Kunzee5e26762020-10-13 16:11:07 -0700226
227 def ClampAttribute(self, minint, maxint, minfp, maxfp):
228 from tosa import ClampAttribute as a, Attribute
229
230 self.utype = Attribute.Attribute().ClampAttribute
231 self.optFcns = (a.ClampAttributeStart, a.ClampAttributeEnd)
232
Kevin Cheng550ccc52021-03-03 11:21:43 -0800233 self.ints.append((a.ClampAttributeAddMinInt, minint))
234 self.ints.append((a.ClampAttributeAddMaxInt, maxint))
Eric Kunzee5e26762020-10-13 16:11:07 -0700235
Kevin Cheng550ccc52021-03-03 11:21:43 -0800236 self.ints.append((a.ClampAttributeAddMinFp, minfp))
237 self.ints.append((a.ClampAttributeAddMaxFp, maxfp))
Eric Kunzee5e26762020-10-13 16:11:07 -0700238
Kevin Cheng550ccc52021-03-03 11:21:43 -0800239 def RescaleAttribute(
240 self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel
241 ):
Eric Kunzee5e26762020-10-13 16:11:07 -0700242 from tosa import RescaleAttribute as a, Attribute
243
244 self.utype = Attribute.Attribute().RescaleAttribute
245 self.optFcns = (a.RescaleAttributeStart, a.RescaleAttributeEnd)
246
Kevin Cheng550ccc52021-03-03 11:21:43 -0800247 self.ints.append((a.RescaleAttributeAddInputZp, input_zp))
248 self.ints.append((a.RescaleAttributeAddOutputZp, output_zp))
249 self.intvecs.append((a.RescaleAttributeAddMultiplier, multiplier))
250 self.intvecs.append((a.RescaleAttributeAddShift, shift))
251 self.bools.append((a.RescaleAttributeAddScale32, scale32))
252 self.bools.append((a.RescaleAttributeAddDoubleRound, double_round))
253 self.bools.append((a.RescaleAttributeAddPerChannel, per_channel))
Eric Kunzee5e26762020-10-13 16:11:07 -0700254
Kevin Chengaee1fac2020-11-11 13:54:06 -0800255 def MulAttribute(self, shift):
256 from tosa import MulAttribute as a, Attribute
257
258 self.utype = Attribute.Attribute().MulAttribute
259 self.optFcns = (a.MulAttributeStart, a.MulAttributeEnd)
260
Kevin Cheng550ccc52021-03-03 11:21:43 -0800261 self.ints.append((a.MulAttributeAddShift, shift))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800262
263 def ArithmeticRightShiftAttribute(self, round):
264 from tosa import ArithmeticRightShiftAttribute as a, Attribute
265
266 self.utype = Attribute.Attribute().ArithmeticRightShiftAttribute
Kevin Cheng550ccc52021-03-03 11:21:43 -0800267 self.optFcns = (
268 a.ArithmeticRightShiftAttributeStart,
269 a.ArithmeticRightShiftAttributeEnd,
270 )
Kevin Chengaee1fac2020-11-11 13:54:06 -0800271
Kevin Cheng550ccc52021-03-03 11:21:43 -0800272 self.bools.append((a.ArithmeticRightShiftAttributeAddRound, round))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800273
Eric Kunzee5e26762020-10-13 16:11:07 -0700274 def CustomAttribute(self, identifier):
275 from tosa import CustomAttribute as a, Attribute
276
277 self.utype = Attribute.Attribute().CustomAttribute
278 self.optFcns = (a.CustomAttributeStart, a.CustomAttributeEnd)
279
Kevin Cheng550ccc52021-03-03 11:21:43 -0800280 self.strings.append((a.CustomAttributeAddIdentifier, identifier))
Eric Kunzee5e26762020-10-13 16:11:07 -0700281
282 def CondIfAttribute(self, then_branch, else_branch):
283 from tosa import CondIfAttribute as a, Attribute
284
285 self.utype = Attribute.Attribute().CondIfAttribute
286 self.optFcns = (a.CondIfAttributeStart, a.CondIfAttributeEnd)
287
Kevin Cheng550ccc52021-03-03 11:21:43 -0800288 self.strings.append((a.CondIfAttributeAddThenBranch, then_branch))
289 self.strings.append((a.CondIfAttributeAddElseBranch, else_branch))
Eric Kunzee5e26762020-10-13 16:11:07 -0700290
291 def WhileLoopAttribute(self, cond_branch, body_branch):
292 from tosa import WhileLoopAttribute as a, Attribute
293
294 self.utype = Attribute.Attribute().WhileLoopAttribute
295 self.optFcns = (a.WhileLoopAttributeStart, a.WhileLoopAttributeEnd)
296
Kevin Cheng550ccc52021-03-03 11:21:43 -0800297 self.strings.append((a.WhileLoopAttributeAddCondBranch, cond_branch))
298 self.strings.append((a.WhileLoopAttributeAddBodyBranch, body_branch))
299
Eric Kunzee5e26762020-10-13 16:11:07 -0700300
301class TosaSerializerQuantInfo(TosaSerializerUnion):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800302 """This class handles encapsulating all of the enumerated types for quantinfo types"""
303
Eric Kunzee5e26762020-10-13 16:11:07 -0700304 def __init__(self):
305 super().__init__()
306
307 def ConvQuantInfo(self, input_zp, weight_zp):
308 from tosa import ConvQuantInfo as q, QuantInfo
309
310 self.utype = QuantInfo.QuantInfo().ConvQuantInfo
311 self.optFcns = (q.ConvQuantInfoStart, q.ConvQuantInfoEnd)
312 self.ints.append((q.ConvQuantInfoAddInputZp, input_zp))
313 self.ints.append((q.ConvQuantInfoAddWeightZp, weight_zp))
314
315 def UnaryQuantInfo(self, input_zp, output_zp):
316 from tosa import UnaryQuantInfo as q, QuantInfo
317
318 self.utype = QuantInfo.QuantInfo().UnaryQuantInfo
319 self.optFcns = (q.UnaryQuantInfoStart, q.UnaryQuantInfoEnd)
320 self.ints.append((q.UnaryQuantInfoAddInputZp, input_zp))
321 self.ints.append((q.UnaryQuantInfoAddOutputZp, output_zp))
322
323 def MatMulQuantInfo(self, a_zp, b_zp):
324 from tosa import MatMulQuantInfo as q, QuantInfo
325
326 self.utype = QuantInfo.QuantInfo().MatMulQuantInfo
327 self.optFcns = (q.MatMulQuantInfoStart, q.MatMulQuantInfoEnd)
328 self.ints.append((q.MatMulQuantInfoAddAZp, a_zp))
329 self.ints.append((q.MatMulQuantInfoAddBZp, b_zp))
330
331 def PadQuantInfo(self, input_zp):
332 from tosa import PadQuantInfo as q, QuantInfo
333
334 self.utype = QuantInfo.QuantInfo().PadQuantInfo
335 self.optFcns = (q.PadQuantInfoStart, q.PadQuantInfoEnd)
336 self.ints.append((q.PadQuantInfoAddInputZp, input_zp))
337
Kevin Cheng550ccc52021-03-03 11:21:43 -0800338
Eric Kunzee5e26762020-10-13 16:11:07 -0700339class TosaSerializerTensor:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800340 def __init__(
341 self,
342 name,
343 shape,
344 dtype,
Kevin Cheng82507d72021-06-17 16:01:59 -0700345 data=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -0800346 placeholderFilename=None,
347 ):
Eric Kunzee5e26762020-10-13 16:11:07 -0700348 self.name = name
349
350 if isinstance(shape, np.ndarray):
351 shape = shape.astype(int).tolist()
352 shape = list(map(int, shape))
353
354 self.shape = shape
355 self.dtype = dtype
Eric Kunzee5e26762020-10-13 16:11:07 -0700356
Kevin Cheng82507d72021-06-17 16:01:59 -0700357 if isinstance(data, np.ndarray):
358 data = data.flatten().astype(int).tolist()
359 data = list(map(int, data))
360 self.data = data
Kevin Chengb227ae52021-09-02 13:43:17 -0700361 elif isinstance(data, list):
362 data = list(map(int, data))
363 self.data = data
Kevin Cheng82507d72021-06-17 16:01:59 -0700364 else:
365 self.data = None
Eric Kunzee5e26762020-10-13 16:11:07 -0700366
367 # Filename for placeholder tensors. These get generated by the test generation
368 # process and are written to disk, but are considered input tensors by the network
369 # so they do not appear in the TOSA serialiazation. However, if we want to form a unit
370 # test around these input tensors, we can get the filename from here.
371 self.placeholderFilename = placeholderFilename
372
373 def __str__(self):
Kevin Cheng82507d72021-06-17 16:01:59 -0700374 str = "TosaSerializerTensor name: {} shape: {} dtype: {}".format(
Kevin Cheng550ccc52021-03-03 11:21:43 -0800375 self.name,
376 self.shape,
377 DTypeNames[self.dtype],
Kevin Cheng550ccc52021-03-03 11:21:43 -0800378 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700379 return str
380
Eric Kunzee5e26762020-10-13 16:11:07 -0700381 def setDtype(self, dtype):
382 self.dtype = dtype
383
Eric Kunzee5e26762020-10-13 16:11:07 -0700384 def serialize(self, builder):
385 fb_name = builder.CreateString(self.name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700386 fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape)
Kevin Cheng82507d72021-06-17 16:01:59 -0700387 if self.data:
388 u8_data = list()
389 # little endianess
390 if self.dtype == DType.BOOL:
391 for val in self.data:
392 val_u8 = np.uint8(val)
393 u8_data.append(val_u8)
Kevin Chenga9017402021-07-28 17:19:23 -0700394 elif self.dtype == DType.INT4:
395 in_size = len(self.data)
396 out_size = (in_size + 1) // 2
397 for i in range(out_size):
398 val_0 = self.data[2 * i]
399 if (2 * i + 1) < in_size:
400 val_1 = self.data[2 * i + 1]
401 else:
402 val_1 = 0
403 val_i8 = (val_0 & 0xF) | ((val_1 & 0xF) << 4)
404 val_u8 = np.uint8(val_i8)
405 u8_data.append(val_u8)
Kevin Cheng82507d72021-06-17 16:01:59 -0700406 elif self.dtype == DType.INT8:
407 for val in self.data:
408 val_u8 = np.uint8(val)
409 u8_data.append(val_u8)
410 elif self.dtype == DType.INT16:
411 for val in self.data:
412 val_u16 = np.uint16(val)
413 b0 = val_u16 & ByteMask
414 b1 = (val_u16 >> np.uint16(8)) & ByteMask
415 u8_data.extend([b0, b1])
416 elif self.dtype == DType.INT32:
417 for val in self.data:
418 val_u32 = np.uint32(val)
419 b0 = val_u32 & ByteMask
420 b1 = (val_u32 >> np.uint32(8)) & ByteMask
421 b2 = (val_u32 >> np.uint32(16)) & ByteMask
422 b3 = (val_u32 >> np.uint32(32)) & ByteMask
423 u8_data.extend([b0, b1, b2, b3])
424 elif self.dtype == DType.INT48:
425 for val in self.data:
426 val_u64 = np.uint64(val)
427 b0 = val_u64 & ByteMask
428 b1 = (val_u64 >> np.uint64(8)) & ByteMask
429 b2 = (val_u64 >> np.uint64(16)) & ByteMask
430 b3 = (val_u64 >> np.uint64(24)) & ByteMask
431 b4 = (val_u64 >> np.uint64(32)) & ByteMask
432 b5 = (val_u64 >> np.uint64(40)) & ByteMask
433 u8_data.extend([b0, b1, b2, b3, b4, b5])
434 elif self.dtype == DType.FLOAT:
435 for val in self.data:
Kevin Chengacb550f2021-06-29 15:32:19 -0700436 b = struct.pack("!f", val)
Kevin Cheng82507d72021-06-17 16:01:59 -0700437 u8_data.extend([b[3], b[2], b[1], b[0]])
438 else:
Kevin Chengacb550f2021-06-29 15:32:19 -0700439 raise Exception(
440 "unsupported data type {}".format(DTypeNames[self.dtype])
441 )
Kevin Cheng82507d72021-06-17 16:01:59 -0700442 fb_data = TosaSerializer.serializeUint8Vec(builder, u8_data)
Eric Kunzee5e26762020-10-13 16:11:07 -0700443
444 TosaTensor.TosaTensorStart(builder)
445 TosaTensor.TosaTensorAddName(builder, fb_name)
446 TosaTensor.TosaTensorAddShape(builder, fb_shapes)
447 TosaTensor.TosaTensorAddType(builder, self.dtype)
Kevin Cheng82507d72021-06-17 16:01:59 -0700448 if self.data:
449 TosaTensor.TosaTensorAddData(builder, fb_data)
Eric Kunzee5e26762020-10-13 16:11:07 -0700450
451 return TosaTensor.TosaTensorEnd(builder)
452
Kevin Cheng550ccc52021-03-03 11:21:43 -0800453
Eric Kunzee5e26762020-10-13 16:11:07 -0700454class TosaSerializerOperator:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800455 def __init__(self, op, inputs, outputs, attributes=None, quantInfo=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700456 self.op = op
457 self.attributes = attributes
458 self.inputs = TosaSerializer.toList(inputs)
459 self.outputs = TosaSerializer.toList(outputs)
460 self.quantInfo = quantInfo
461
462 def __str__(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800463 str = "Op {}\n----\n".format(self.op)
Eric Kunzee5e26762020-10-13 16:11:07 -0700464
465 for i in self.inputs:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800466 str = str + " Input: {}\n".format(i)
Eric Kunzee5e26762020-10-13 16:11:07 -0700467 for o in self.outputs:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800468 str = str + " Output: {}\n".format(o)
Eric Kunzee5e26762020-10-13 16:11:07 -0700469
470 return str
471
472 def serialize(self, builder):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800473 fb_inputs = TosaSerializer.serializeStrVec(
474 builder, self.inputs, TosaOperator.TosaOperatorStartInputsVector
475 )
476 fb_outputs = TosaSerializer.serializeStrVec(
477 builder, self.outputs, TosaOperator.TosaOperatorStartOutputsVector
478 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700479 # Need to serialize quant_info and attributes enums still
480 if self.attributes is not None:
481 fb_attributes = self.attributes.serialize(builder)
482
483 if self.quantInfo is not None:
484 fb_qinfo = self.quantInfo.serialize(builder)
485
486 TosaOperator.TosaOperatorStart(builder)
487 TosaOperator.TosaOperatorAddOp(builder, self.op)
488 TosaOperator.TosaOperatorAddInputs(builder, fb_inputs)
489 TosaOperator.TosaOperatorAddOutputs(builder, fb_outputs)
490 if self.attributes is not None:
491 TosaOperator.TosaOperatorAddAttributeType(builder, self.attributes.utype)
492 TosaOperator.TosaOperatorAddAttribute(builder, fb_attributes)
493 if self.quantInfo is not None:
494 TosaOperator.TosaOperatorAddQuantInfoType(builder, self.quantInfo.utype)
495 TosaOperator.TosaOperatorAddQuantInfo(builder, fb_qinfo)
496
497 return TosaOperator.TosaOperatorEnd(builder)
498
Kevin Cheng550ccc52021-03-03 11:21:43 -0800499
Eric Kunzee5e26762020-10-13 16:11:07 -0700500class TosaSerializerBasicBlock:
501 def __init__(self, name):
502 self.name = name
503 self.operators = []
504
505 # Dict assures uniqueness, but allows us to look up by name
506 self.tensors = dict()
507
508 self.inputs = []
509 self.outputs = []
510
Kevin Cheng550ccc52021-03-03 11:21:43 -0800511 def addTensor(
512 self,
513 name,
514 shape,
515 dtype,
Kevin Cheng82507d72021-06-17 16:01:59 -0700516 data=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -0800517 placeholderFilename=None,
518 ):
Eric Kunzee5e26762020-10-13 16:11:07 -0700519 try:
520 # Someone already added this tensor.
Eric Kunzee5e26762020-10-13 16:11:07 -0700521 tens = self.tensors[name]
Eric Kunzee5e26762020-10-13 16:11:07 -0700522 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800523 self.tensors[name] = TosaSerializerTensor(
Kevin Cheng82507d72021-06-17 16:01:59 -0700524 name, shape, dtype, data, placeholderFilename
Kevin Cheng550ccc52021-03-03 11:21:43 -0800525 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700526
527 return self.tensors[name]
528
529 def addInput(self, name):
530 self.inputs.append(name)
531
532 def addOutput(self, name):
533 self.outputs.append(name)
534
Kevin Cheng550ccc52021-03-03 11:21:43 -0800535 def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None):
536 self.operators.append(
537 TosaSerializerOperator(op, inputs, outputs, attributes, quant_info)
538 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700539
540 def serialize(self, builder):
541 fb_name = builder.CreateString(self.name)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800542 fbv_inputs = TosaSerializer.serializeStrVec(
543 builder, list(self.inputs), TosaBasicBlock.TosaBasicBlockStartInputsVector
544 )
545 fbv_outputs = TosaSerializer.serializeStrVec(
546 builder, list(self.outputs), TosaBasicBlock.TosaBasicBlockStartOutputsVector
547 )
548 fbv_tensors = TosaSerializer.serializeObjVec(
549 builder,
550 list(self.tensors.values()),
551 TosaBasicBlock.TosaBasicBlockStartTensorsVector,
552 )
553 fbv_operators = TosaSerializer.serializeObjVec(
554 builder, self.operators, TosaBasicBlock.TosaBasicBlockStartOperatorsVector
555 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700556
557 TosaBasicBlock.TosaBasicBlockStart(builder)
558 TosaBasicBlock.TosaBasicBlockAddName(builder, fb_name)
559 TosaBasicBlock.TosaBasicBlockAddInputs(builder, fbv_inputs)
560 TosaBasicBlock.TosaBasicBlockAddOutputs(builder, fbv_outputs)
561 TosaBasicBlock.TosaBasicBlockAddTensors(builder, fbv_tensors)
562 TosaBasicBlock.TosaBasicBlockAddOperators(builder, fbv_operators)
563 return TosaBasicBlock.TosaBasicBlockEnd(builder)
564
Kevin Cheng550ccc52021-03-03 11:21:43 -0800565
Eric Kunzee5e26762020-10-13 16:11:07 -0700566@unique
567class TensorDir(IntEnum):
568 PLACEHOLDER = 0
569 CONST = 1
570 INTERMEDIATE = 2
571 RESULT = 3
572
Kevin Cheng550ccc52021-03-03 11:21:43 -0800573
Eric Kunzee5e26762020-10-13 16:11:07 -0700574class TosaSerializer:
575 def __init__(self, pathPrefix):
576
577 # Get the global TOSA version if not already defined
578 try:
579 TOSA_VERSION
580 except NameError:
581 TosaSerializer.setTosaVersion()
582
583 self.builder = flatbuffers.Builder(0)
584
585 self.basicBlocks = []
Kevin Cheng550ccc52021-03-03 11:21:43 -0800586 self.startBasicBlock("main")
Eric Kunzee5e26762020-10-13 16:11:07 -0700587 self.pathPrefix = pathPrefix
588
589 # Indicies used for adding/naming tensors
590 self.currInputIdx = 0
591 self.currConstIdx = 0
592 self.currLayerIdx = 1
593 self.currResultIdx = 0
594
595 # Is this an illegal test that is expected to fail?
Kevin Chengacb550f2021-06-29 15:32:19 -0700596 self.expectedReturnCode = TosaReturnCode.VALID
Jared Smolensa9d53952021-08-24 23:48:19 +0000597 self.expectedFailure = False
Kevin Cheng550ccc52021-03-03 11:21:43 -0800598 self.expectedFailureDesc = ""
Eric Kunzee5e26762020-10-13 16:11:07 -0700599
600 def __str__(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800601 str = ""
Eric Kunzee5e26762020-10-13 16:11:07 -0700602 for bb in self.basicBlocks:
603 str = str + bb.__str__()
604 return str
605
Kevin Cheng550ccc52021-03-03 11:21:43 -0800606 def addPlaceholder(self, shape, dtype, vals):
Eric Kunzee5e26762020-10-13 16:11:07 -0700607 if not self.currBasicBlock:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800608 raise Exception("addTensor called without valid basic block")
Eric Kunzee5e26762020-10-13 16:11:07 -0700609
Kevin Cheng550ccc52021-03-03 11:21:43 -0800610 name = "input-{}".format(self.currInputIdx)
611 filename = "{}.npy".format(name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700612 self.currInputIdx = self.currInputIdx + 1
613
Kevin Cheng550ccc52021-03-03 11:21:43 -0800614 tens = self.currBasicBlock.addTensor(name, shape, dtype, None, filename)
Eric Kunzee5e26762020-10-13 16:11:07 -0700615 # This is always an input to the block
616 self.currBasicBlock.addInput(name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700617
618 if vals is not None:
619 np.save(os.path.join(self.pathPrefix, filename), vals, False)
620
621 return tens
622
Kevin Cheng550ccc52021-03-03 11:21:43 -0800623 def addConst(self, shape, dtype, vals):
Eric Kunzee5e26762020-10-13 16:11:07 -0700624 if not self.currBasicBlock:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800625 raise Exception("addTensor called without valid basic block")
Eric Kunzee5e26762020-10-13 16:11:07 -0700626
Kevin Cheng550ccc52021-03-03 11:21:43 -0800627 name = "const-{}".format(self.currInputIdx)
628 filename = "{}.npy".format(name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700629 self.currInputIdx = self.currInputIdx + 1
630
Kevin Cheng82507d72021-06-17 16:01:59 -0700631 tens = self.currBasicBlock.addTensor(name, shape, dtype, vals)
Eric Kunzee5e26762020-10-13 16:11:07 -0700632 # Add the operator now
633 self.currBasicBlock.addOperator(tosa.Op.Op().CONST, [], name)
634
Eric Kunzee5e26762020-10-13 16:11:07 -0700635 return tens
636
Kevin Cheng550ccc52021-03-03 11:21:43 -0800637 def addIntermediate(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -0700638
639 if not self.currBasicBlock:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800640 raise Exception("addTensor called without valid basic block")
Eric Kunzee5e26762020-10-13 16:11:07 -0700641
Kevin Cheng550ccc52021-03-03 11:21:43 -0800642 name = "layer-{}".format(self.currLayerIdx)
Eric Kunzee5e26762020-10-13 16:11:07 -0700643 self.currLayerIdx = self.currLayerIdx + 1
644
Kevin Cheng82507d72021-06-17 16:01:59 -0700645 tens = self.currBasicBlock.addTensor(name, shape, dtype, None)
Eric Kunzee5e26762020-10-13 16:11:07 -0700646
647 return tens
648
649 def addInputTensor(self, tensor):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800650 self.currBasicBlock.addTensor(tensor.name, tensor.shape, tensor.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -0700651 self.currBasicBlock.addInput(tensor.name)
652
653 def addOutputTensor(self, tensor):
654 self.currBasicBlock.addOutput(tensor.name)
655
Kevin Cheng550ccc52021-03-03 11:21:43 -0800656 def addOutput(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -0700657 if not self.currBasicBlock:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800658 raise Exception("addTensor called without valid basic block")
Eric Kunzee5e26762020-10-13 16:11:07 -0700659
Kevin Cheng550ccc52021-03-03 11:21:43 -0800660 name = "result-{}".format(self.currResultIdx)
Eric Kunzee5e26762020-10-13 16:11:07 -0700661 self.currResultIdx = self.currResultIdx + 1
662
Kevin Cheng550ccc52021-03-03 11:21:43 -0800663 tens = self.currBasicBlock.addTensor(name, shape, dtype, None)
Eric Kunzee5e26762020-10-13 16:11:07 -0700664 self.currBasicBlock.addOutput(name)
665 return tens
666
Kevin Cheng550ccc52021-03-03 11:21:43 -0800667 def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700668
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700669 if op == tosa.Op.Op().CONST:
670 raise Exception("Use addConstTensor() to add CONST ops")
Eric Kunzee5e26762020-10-13 16:11:07 -0700671
Kevin Cheng550ccc52021-03-03 11:21:43 -0800672 return self.currBasicBlock.addOperator(
673 op, inputs, outputs, attributes, quant_info
674 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700675
Kevin Chengacb550f2021-06-29 15:32:19 -0700676 def setExpectedReturnCode(self, val, desc=""):
Eric Kunzee5e26762020-10-13 16:11:07 -0700677
Kevin Chengacb550f2021-06-29 15:32:19 -0700678 self.expectedReturnCode = val
Eric Kunzee5e26762020-10-13 16:11:07 -0700679 self.expectedFailureDesc = desc
680
Jared Smolensa9d53952021-08-24 23:48:19 +0000681 if val == TosaReturnCode.VALID:
682 self.expectedFailure = False
683 else:
684 # Unpredictable or error results are considered expected failures
685 # for conformance
686 self.expectedFailure = True
687
Eric Kunzee5e26762020-10-13 16:11:07 -0700688 def serialize(self):
689
690 builder = self.builder
691
692 Version.VersionStart(builder)
693 Version.VersionAdd_major(builder, TOSA_VERSION[0])
694 Version.VersionAdd_minor(builder, TOSA_VERSION[1])
695 Version.VersionAdd_patch(builder, TOSA_VERSION[2])
696 Version.VersionAdd_experimental(builder, TOSA_VERSION[3])
697 version = Version.VersionEnd(builder)
698
Kevin Cheng550ccc52021-03-03 11:21:43 -0800699 fbv_bb = TosaSerializer.serializeObjVec(
700 builder, self.basicBlocks, TosaGraph.TosaGraphStartBlocksVector
701 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700702
703 TosaGraph.TosaGraphStart(builder)
704 TosaGraph.TosaGraphAddVersion(builder, version)
705 TosaGraph.TosaGraphAddBlocks(builder, fbv_bb)
706 graph = TosaGraph.TosaGraphEnd(builder)
707
708 self.builder.Finish(graph)
709 return self.builder.Output()
710
711 def writeJson(self, tosa_filename):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800712 """Write a json test file so that it is fairly easy to pick up the test
713 and generate commands for third party tool"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700714 test_desc = dict()
715
Kevin Cheng550ccc52021-03-03 11:21:43 -0800716 test_desc["tosa_file"] = tosa_filename
Eric Kunzee5e26762020-10-13 16:11:07 -0700717 ifm_name = []
Eric Kunzee5e26762020-10-13 16:11:07 -0700718 ifm_file = []
719 ofm_name = []
720 ofm_file = []
Eric Kunzee5e26762020-10-13 16:11:07 -0700721
722 for b in self.basicBlocks:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800723 if b.name == "main":
Eric Kunzee5e26762020-10-13 16:11:07 -0700724 for i in b.inputs:
725 ifm_name.append(i)
Eric Kunzee5e26762020-10-13 16:11:07 -0700726 ifm_file.append(b.tensors[i].placeholderFilename)
727 for o in b.outputs:
728 ofm_name.append(o)
Eric Kunzee5e26762020-10-13 16:11:07 -0700729 # Make up an OFM filename here. One isn't generated until the reference tool is
730 # run, so any name is a good name
Kevin Cheng550ccc52021-03-03 11:21:43 -0800731 ofm_file.append("ref-{}.npy".format(o))
Eric Kunzee5e26762020-10-13 16:11:07 -0700732
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700733 test_desc["ifm_name"] = ifm_name
Kevin Cheng550ccc52021-03-03 11:21:43 -0800734 test_desc["ifm_file"] = ifm_file
Kevin Cheng550ccc52021-03-03 11:21:43 -0800735 test_desc["ofm_name"] = ofm_name
Kevin Cheng550ccc52021-03-03 11:21:43 -0800736 test_desc["ofm_file"] = ofm_file
Kevin Chengacb550f2021-06-29 15:32:19 -0700737 test_desc["expected_return_code"] = self.expectedReturnCode
Jared Smolensa9d53952021-08-24 23:48:19 +0000738 test_desc["expected_failure"] = self.expectedFailure
Eric Kunzee5e26762020-10-13 16:11:07 -0700739 if self.expectedFailureDesc:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800740 test_desc["expected_failure_desc"] = self.expectedFailureDesc
Eric Kunzee5e26762020-10-13 16:11:07 -0700741
Kevin Cheng550ccc52021-03-03 11:21:43 -0800742 return json.dumps(test_desc, indent=" ")
Eric Kunzee5e26762020-10-13 16:11:07 -0700743
744 def startBasicBlock(self, name):
745 self.currBasicBlock = TosaSerializerBasicBlock(name)
746 self.basicBlocks.append(self.currBasicBlock)
747
748 @staticmethod
749 def serializeStrVec(builder, vec, start_fcn):
750 fb_strs = [builder.CreateString(i) for i in vec]
751 start_fcn(builder, len(fb_strs))
752 for s in fb_strs[::-1]:
753 builder.PrependUOffsetTRelative(s)
Jared Smolensc94e63d2021-09-17 21:58:27 -0700754 # This try/except block supports both the Flatbuffers 2.x and 1.x APIs,
755 # defaulting to 2.x. If/when Flatbuffers 1.x support is deprecated, the
756 # try block and builder.EndVector(len) function calls can be removed.
757 try:
758 return builder.EndVector()
759 except TypeError:
760 return builder.EndVector(len(fb_strs))
Eric Kunzee5e26762020-10-13 16:11:07 -0700761
762 @staticmethod
Kevin Cheng82507d72021-06-17 16:01:59 -0700763 def serializeUint8Vec(builder, vec):
764 builder.StartVector(1, len(vec), 8)
765 for v in vec[::-1]:
766 builder.PrependUint8(v)
Jared Smolensc94e63d2021-09-17 21:58:27 -0700767 try:
768 return builder.EndVector()
769 except TypeError:
770 return builder.EndVector(len(vec))
Kevin Cheng82507d72021-06-17 16:01:59 -0700771
772 @staticmethod
Eric Kunzee5e26762020-10-13 16:11:07 -0700773 def serializeInt32Vec(builder, vec):
774 builder.StartVector(4, len(vec), 4)
775 for v in vec[::-1]:
776 builder.PrependInt32(v)
Jared Smolensc94e63d2021-09-17 21:58:27 -0700777 try:
778 return builder.EndVector()
779 except TypeError:
780 return builder.EndVector(len(vec))
Eric Kunzee5e26762020-10-13 16:11:07 -0700781
782 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -0800783 def serializeFpVec(builder, vec):
784 builder.StartVector(4, len(vec), 4)
785 for v in vec[::-1]:
786 builder.PrependFloat32(v)
Jared Smolensc94e63d2021-09-17 21:58:27 -0700787 try:
788 return builder.EndVector()
789 except TypeError:
790 return builder.EndVector(len(vec))
Kevin Cheng77d0f762020-11-24 10:26:32 -0800791
792 @staticmethod
Eric Kunzee5e26762020-10-13 16:11:07 -0700793 def serializeObjVec(builder, vec, start_fcn):
794 serialized_vec = []
795 for v in vec[::-1]:
796 serialized_vec.append(v.serialize(builder))
797
798 start_fcn(builder, len(vec))
799 for v in serialized_vec:
800 builder.PrependUOffsetTRelative(v)
Jared Smolensc94e63d2021-09-17 21:58:27 -0700801 try:
802 return builder.EndVector()
803 except TypeError:
804 return builder.EndVector(len(vec))
Eric Kunzee5e26762020-10-13 16:11:07 -0700805
806 @staticmethod
807 def toList(val):
808 if isinstance(val, list):
809 return val
810 else:
811 return [val]
812
813 @staticmethod
814 def setTosaVersion():
815 # Create a dummy flatbuffers file with the default version information
816 # There does not appear to be a better way to get a constant from a
817 # flatbuffer schema file
818 builder = flatbuffers.Builder(0)
819 Version.VersionStart(builder)
820 ver = Version.VersionEnd(builder)
821 TosaGraph.TosaGraphStart(builder)
822 TosaGraph.TosaGraphAddVersion(builder, ver)
823 gr = TosaGraph.TosaGraphEnd(builder)
824 builder.Finish(gr)
825
826 out = builder.Output()
827
828 gr = TosaGraph.TosaGraph()
829 root = gr.GetRootAsTosaGraph(out, 0)
830
831 # Store the version as a global variable so that it only needs to be
832 # generated once per process.
833 global TOSA_VERSION
Kevin Cheng550ccc52021-03-03 11:21:43 -0800834 TOSA_VERSION = [
835 root.Version()._major(),
836 root.Version()._minor(),
837 root.Version()._patch(),
838 root.Version()._experimental(),
839 ]