blob: 35dd9a2a6fe96631801794a84be882594f8f291d [file] [log] [blame]
Kevin Cheng3a478572021-01-22 17:21:02 -08001# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07002#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15#!/usr/bin/env python3
16
import os
import sys
import json
import flatbuffers
import numpy as np
import struct
from enum import Enum, IntEnum, unique

# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH.
# This has to happen BEFORE any `tosa` import below; when it ran after
# `from tosa import (...)` the appended path could never help resolve it.
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(
    os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
)

from tosa import (
    TosaGraph,
    TosaBasicBlock,
    TosaTensor,
    TosaOperator,
    DType,
    Op,
    ResizeMode,
    Version,
)
from tosa_ref_run import TosaReturnCode
import tosa
Eric Kunzee5e26762020-10-13 16:11:07 -070042
# flatc's generated Python enum classes only expose integer constants, so there
# is no programmatic way to turn a DType value back into its name. This table is
# maintained by hand and is indexed by the DType integer value.
DType = tosa.DType.DType()
DTypeNames = [
    "UNKNOWN",  # 0
    "BOOL",  # 1
    "UINT8",  # 2
    "INT4",  # 3
    "INT8",  # 4
    "INT16",  # 5
    "INT32",  # 6
    "INT48",  # 7
    "FLOAT",  # 8
]

# Selects the least-significant byte of an unsigned 64-bit numpy value.
ByteMask = np.uint64(0xFF)
Eric Kunzee5e26762020-10-13 16:11:07 -070060
Kevin Chengacb550f2021-06-29 15:32:19 -070061
def dtype_str_to_val(name):
    """Return the DType integer value whose name matches ``name``.

    The comparison is case-insensitive (``str.casefold``) against the
    ``DTypeNames`` table; the returned value is the table index.

    Raises:
        ValueError: if ``name`` matches no entry. ValueError derives from
            Exception, so callers catching ``Exception`` still catch it.
    """
    # Fold the query once instead of on every loop iteration.
    wanted = name.casefold()
    for value, dtype_name in enumerate(DTypeNames):
        if dtype_name.casefold() == wanted:
            return value
    raise ValueError("Unable to parse DType name {}".format(name))
Eric Kunzee5e26762020-10-13 16:11:07 -070068
69
class TosaSerializerUnion:
    """Collects the fields of one flatbuffers union member and serializes them.

    Subclasses assign ``optFcns`` (the table's start/end builder functions) and
    ``utype`` (the union tag), then queue ``(add_fn, value)`` pairs on the
    per-kind lists below. ``serialize`` writes them all into a single table.
    """

    def __init__(self):
        # (start_fn, end_fn) for the table; assigned by a subclass method.
        self.optFcns = None

        # Union type tag for the table; assigned by a subclass method.
        self.utype = None

        # Queued (add_fn, value) pairs, grouped by the kind of value.
        self.ints = []
        self.bools = []
        self.floats = []
        self.strings = []
        self.intvecs = []
        self.fpvecs = []

    def serialize(self, builder):
        """Write the queued fields into a table and return end_fn's result."""
        # Strings and vectors have to exist in the buffer before the table is
        # started, so convert them to offsets first (strings, int vectors,
        # then float vectors).
        string_offsets = [
            (add_fn, builder.CreateString(text)) for add_fn, text in self.strings
        ]
        intvec_offsets = [
            (add_fn, TosaSerializer.serializeInt32Vec(builder, values))
            for add_fn, values in self.intvecs
        ]
        fpvec_offsets = [
            (add_fn, TosaSerializer.serializeFpVec(builder, values))
            for add_fn, values in self.fpvecs
        ]

        start_table, end_table = self.optFcns

        start_table(builder)
        # Emit scalars first, then the offsets built above, in list order.
        queued = (
            self.ints
            + self.bools
            + self.floats
            + string_offsets
            + intvec_offsets
            + fpvec_offsets
        )
        for add_fn, value in queued:
            add_fn(builder, value)
        return end_table(builder)
130
Kevin Cheng550ccc52021-03-03 11:21:43 -0800131
class TosaSerializerAttribute(TosaSerializerUnion):
    """This class handles encapsulating all of the enumerated types for attributes.

    Each method selects one member of the ``tosa.Attribute`` union: it sets
    ``utype`` and ``optFcns`` and queues the member's field values on the
    per-kind lists inherited from TosaSerializerUnion. The lists are appended
    to and never reset, so call only one of these methods per instance.
    """

    def __init__(self):
        super().__init__()

    def Pool2dAttribute(self, kernel, stride, padding):
        from tosa import Pool2dAttribute as a, Attribute

        self.utype = Attribute.Attribute().Pool2dAttribute

        self.optFcns = (a.Pool2dAttributeStart, a.Pool2dAttributeEnd)
        self.intvecs.append((a.Pool2dAttributeAddPadding, padding))
        self.intvecs.append((a.Pool2dAttributeAddKernel, kernel))
        self.intvecs.append((a.Pool2dAttributeAddStride, stride))

    def Conv2dAttribute(self, padding, stride, dilation):
        from tosa import Conv2dAttribute as a, Attribute

        self.utype = Attribute.Attribute().Conv2dAttribute
        self.optFcns = (a.Conv2dAttributeStart, a.Conv2dAttributeEnd)

        self.intvecs.append((a.Conv2dAttributeAddPadding, padding))
        self.intvecs.append((a.Conv2dAttributeAddStride, stride))
        self.intvecs.append((a.Conv2dAttributeAddDilation, dilation))

    def TransposeConv2DAttribute(self, outpad, stride, dilation, output_shape):
        from tosa import TransposeConv2dAttribute as a, Attribute

        self.utype = Attribute.Attribute().TransposeConv2dAttribute
        self.optFcns = (a.TransposeConv2dAttributeStart, a.TransposeConv2dAttributeEnd)

        self.intvecs.append((a.TransposeConv2dAttributeAddOutpad, outpad))
        self.intvecs.append((a.TransposeConv2dAttributeAddStride, stride))
        self.intvecs.append((a.TransposeConv2dAttributeAddDilation, dilation))
        self.intvecs.append((a.TransposeConv2dAttributeAddOutputShape, output_shape))

    def ReluNAttribute(self, maxint, maxfp):
        from tosa import ReluNAttribute as a, Attribute

        self.utype = Attribute.Attribute().ReluNAttribute
        self.optFcns = (a.ReluNAttributeStart, a.ReluNAttributeEnd)

        self.ints.append((a.ReluNAttributeAddMaxInt, maxint))
        # The fp bound is queued as a float rather than an int. serialize()
        # emits ints before floats, so the field order is unchanged.
        self.floats.append((a.ReluNAttributeAddMaxFp, maxfp))

    def AxisAttribute(self, axis):
        from tosa import AxisAttribute as a, Attribute

        self.utype = Attribute.Attribute().AxisAttribute
        self.optFcns = (a.AxisAttributeStart, a.AxisAttributeEnd)

        self.ints.append((a.AxisAttributeAddAxis, axis))

    def ReshapeAttribute(self, shape):
        from tosa import ReshapeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ReshapeAttribute
        self.optFcns = (a.ReshapeAttributeStart, a.ReshapeAttributeEnd)

        self.intvecs.append((a.ReshapeAttributeAddShape, shape))

    def SliceAttribute(self, begin, size):
        from tosa import SliceAttribute as a, Attribute

        self.utype = Attribute.Attribute().SliceAttribute
        self.optFcns = (a.SliceAttributeStart, a.SliceAttributeEnd)

        self.intvecs.append((a.SliceAttributeAddBegin, begin))
        self.intvecs.append((a.SliceAttributeAddSize, size))

    def TileAttribute(self, multiples):
        from tosa import TileAttribute as a, Attribute

        self.utype = Attribute.Attribute().TileAttribute
        self.optFcns = (a.TileAttributeStart, a.TileAttributeEnd)

        self.intvecs.append((a.TileAttributeAddMultiples, multiples))

    def ResizeAttribute(
        self, output_size, stride, offset, shift, stride_fp, offset_fp, mode
    ):
        from tosa import ResizeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ResizeAttribute
        self.optFcns = (a.ResizeAttributeStart, a.ResizeAttributeEnd)

        self.intvecs.append((a.ResizeAttributeAddOutputSize, output_size))
        self.intvecs.append((a.ResizeAttributeAddStride, stride))
        self.intvecs.append((a.ResizeAttributeAddOffset, offset))
        self.ints.append((a.ResizeAttributeAddShift, shift))
        self.fpvecs.append((a.ResizeAttributeAddStrideFp, stride_fp))
        self.fpvecs.append((a.ResizeAttributeAddOffsetFp, offset_fp))
        self.ints.append((a.ResizeAttributeAddMode, mode))

    def ClampAttribute(self, minint, maxint, minfp, maxfp):
        from tosa import ClampAttribute as a, Attribute

        self.utype = Attribute.Attribute().ClampAttribute
        self.optFcns = (a.ClampAttributeStart, a.ClampAttributeEnd)

        self.ints.append((a.ClampAttributeAddMinInt, minint))
        self.ints.append((a.ClampAttributeAddMaxInt, maxint))

        # The fp bounds are queued as floats. serialize() emits all ints
        # before floats, so the order is still minint, maxint, minfp, maxfp.
        self.floats.append((a.ClampAttributeAddMinFp, minfp))
        self.floats.append((a.ClampAttributeAddMaxFp, maxfp))

    def RescaleAttribute(
        self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel
    ):
        from tosa import RescaleAttribute as a, Attribute

        self.utype = Attribute.Attribute().RescaleAttribute
        self.optFcns = (a.RescaleAttributeStart, a.RescaleAttributeEnd)

        self.ints.append((a.RescaleAttributeAddInputZp, input_zp))
        self.ints.append((a.RescaleAttributeAddOutputZp, output_zp))
        self.intvecs.append((a.RescaleAttributeAddMultiplier, multiplier))
        self.intvecs.append((a.RescaleAttributeAddShift, shift))
        self.bools.append((a.RescaleAttributeAddScale32, scale32))
        self.bools.append((a.RescaleAttributeAddDoubleRound, double_round))
        self.bools.append((a.RescaleAttributeAddPerChannel, per_channel))

    def MulAttribute(self, shift):
        from tosa import MulAttribute as a, Attribute

        self.utype = Attribute.Attribute().MulAttribute
        self.optFcns = (a.MulAttributeStart, a.MulAttributeEnd)

        self.ints.append((a.MulAttributeAddShift, shift))

    def ArithmeticRightShiftAttribute(self, round):
        from tosa import ArithmeticRightShiftAttribute as a, Attribute

        self.utype = Attribute.Attribute().ArithmeticRightShiftAttribute
        self.optFcns = (
            a.ArithmeticRightShiftAttributeStart,
            a.ArithmeticRightShiftAttributeEnd,
        )

        self.bools.append((a.ArithmeticRightShiftAttributeAddRound, round))

    def CustomAttribute(self, identifier):
        from tosa import CustomAttribute as a, Attribute

        self.utype = Attribute.Attribute().CustomAttribute
        self.optFcns = (a.CustomAttributeStart, a.CustomAttributeEnd)

        self.strings.append((a.CustomAttributeAddIdentifier, identifier))

    def CondIfAttribute(self, then_branch, else_branch):
        from tosa import CondIfAttribute as a, Attribute

        self.utype = Attribute.Attribute().CondIfAttribute
        self.optFcns = (a.CondIfAttributeStart, a.CondIfAttributeEnd)

        self.strings.append((a.CondIfAttributeAddThenBranch, then_branch))
        self.strings.append((a.CondIfAttributeAddElseBranch, else_branch))

    def WhileLoopAttribute(self, cond_branch, body_branch):
        from tosa import WhileLoopAttribute as a, Attribute

        self.utype = Attribute.Attribute().WhileLoopAttribute
        self.optFcns = (a.WhileLoopAttributeStart, a.WhileLoopAttributeEnd)

        self.strings.append((a.WhileLoopAttributeAddCondBranch, cond_branch))
        self.strings.append((a.WhileLoopAttributeAddBodyBranch, body_branch))
298 self.strings.append((a.WhileLoopAttributeAddBodyBranch, body_branch))
299
Eric Kunzee5e26762020-10-13 16:11:07 -0700300
class TosaSerializerQuantInfo(TosaSerializerUnion):
    """Selects and fills one member of the ``tosa.QuantInfo`` union.

    Each method records the union tag, the table's start/end builder
    functions, and queues its zero-point fields on the ``ints`` list.
    """

    def __init__(self):
        super().__init__()

    def ConvQuantInfo(self, input_zp, weight_zp):
        from tosa import ConvQuantInfo as table, QuantInfo

        self.utype = QuantInfo.QuantInfo().ConvQuantInfo
        self.optFcns = (table.ConvQuantInfoStart, table.ConvQuantInfoEnd)
        self.ints.extend(
            [
                (table.ConvQuantInfoAddInputZp, input_zp),
                (table.ConvQuantInfoAddWeightZp, weight_zp),
            ]
        )

    def UnaryQuantInfo(self, input_zp, output_zp):
        from tosa import UnaryQuantInfo as table, QuantInfo

        self.utype = QuantInfo.QuantInfo().UnaryQuantInfo
        self.optFcns = (table.UnaryQuantInfoStart, table.UnaryQuantInfoEnd)
        self.ints.extend(
            [
                (table.UnaryQuantInfoAddInputZp, input_zp),
                (table.UnaryQuantInfoAddOutputZp, output_zp),
            ]
        )

    def MatMulQuantInfo(self, a_zp, b_zp):
        from tosa import MatMulQuantInfo as table, QuantInfo

        self.utype = QuantInfo.QuantInfo().MatMulQuantInfo
        self.optFcns = (table.MatMulQuantInfoStart, table.MatMulQuantInfoEnd)
        self.ints.extend(
            [
                (table.MatMulQuantInfoAddAZp, a_zp),
                (table.MatMulQuantInfoAddBZp, b_zp),
            ]
        )

    def PadQuantInfo(self, input_zp):
        from tosa import PadQuantInfo as table, QuantInfo

        self.utype = QuantInfo.QuantInfo().PadQuantInfo
        self.optFcns = (table.PadQuantInfoStart, table.PadQuantInfoEnd)
        self.ints.extend([(table.PadQuantInfoAddInputZp, input_zp)])
337
Kevin Cheng550ccc52021-03-03 11:21:43 -0800338
class TosaSerializerTensor:
    """One tensor of a basic block: name, shape, dtype and optional constant data."""

    def __init__(
        self,
        name,
        shape,
        dtype,
        data=None,
        placeholderFilename=None,
    ):
        self.name = name

        if isinstance(shape, np.ndarray):
            shape = shape.astype(int).tolist()
        shape = list(map(int, shape))

        self.shape = shape
        self.dtype = dtype

        # Constant payload, flattened into a Python list. Any array-like is
        # accepted (previously anything but an ndarray was silently dropped).
        # FLOAT data stays float; converting it with astype(int) truncated
        # fractional values before they were packed as float32. Every other
        # dtype is converted to int, as before.
        if data is None:
            self.data = None
        else:
            flat = np.asarray(data).flatten()
            if dtype == DType.FLOAT:
                self.data = flat.astype(float).tolist()
            else:
                self.data = flat.astype(int).tolist()

        # Filename for placeholder tensors. These get generated by the test generation
        # process and are written to disk, but are considered input tensors by the network
        # so they do not appear in the TOSA serialization. However, if we want to form a unit
        # test around these input tensors, we can get the filename from here.
        self.placeholderFilename = placeholderFilename

    def __str__(self):
        return "TosaSerializerTensor name: {} shape: {} dtype: {}".format(
            self.name,
            self.shape,
            DTypeNames[self.dtype],
        )

    def setDtype(self, dtype):
        self.dtype = dtype

    def serialize(self, builder):
        """Serialize this tensor into ``builder`` and return the TosaTensor offset."""
        fb_name = builder.CreateString(self.name)
        fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape)
        # The data vector must be created before the tensor table is started.
        if self.data:
            fb_data = TosaSerializer.serializeUint8Vec(builder, self._packData())

        TosaTensor.TosaTensorStart(builder)
        TosaTensor.TosaTensorAddName(builder, fb_name)
        TosaTensor.TosaTensorAddShape(builder, fb_shapes)
        TosaTensor.TosaTensorAddType(builder, self.dtype)
        if self.data:
            TosaTensor.TosaTensorAddData(builder, fb_data)

        return TosaTensor.TosaTensorEnd(builder)

    def _packData(self):
        """Return ``self.data`` packed as a list of little-endian byte values (0-255).

        Raises:
            Exception: if ``self.dtype`` has no packing rule here.
        """
        if self.dtype == DType.FLOAT:
            # IEEE-754 binary32, little-endian.
            packed = []
            for val in self.data:
                packed.extend(struct.pack("<f", float(val)))
            return packed

        if self.dtype == DType.INT4:
            # Two 4-bit values per byte: element 2*i in the low nibble,
            # element 2*i+1 in the high nibble (zero-padded if the count is odd).
            count = len(self.data)
            packed = []
            for i in range(0, count, 2):
                lo = int(self.data[i])
                hi = int(self.data[i + 1]) if i + 1 < count else 0
                packed.append((lo & 0xF) | ((hi & 0xF) << 4))
            return packed

        # Bytes per element for the remaining integer types. Values are written
        # as little-endian two's complement, truncated to this width. Masking
        # plain Python ints wraps negative values correctly. Going through
        # np.uintN(negative) raises OverflowError on NumPy >= 2, and the old
        # INT32 path shifted the top byte by 32 instead of 24.
        byte_widths = {
            DType.BOOL: 1,
            DType.INT8: 1,
            DType.INT16: 2,
            DType.INT32: 4,
            DType.INT48: 6,
        }
        try:
            width = byte_widths[self.dtype]
        except KeyError:
            raise Exception(
                "unsupported data type {}".format(DTypeNames[self.dtype])
            ) from None

        packed = []
        for val in self.data:
            ival = int(val)
            packed.extend((ival >> (8 * b)) & 0xFF for b in range(width))
        return packed
448 return TosaTensor.TosaTensorEnd(builder)
449
Kevin Cheng550ccc52021-03-03 11:21:43 -0800450
class TosaSerializerOperator:
    """A single operator: opcode, input/output tensor names, and the optional
    attribute and quantization-info union payloads."""

    def __init__(self, op, inputs, outputs, attributes=None, quantInfo=None):
        self.op = op
        self.attributes = attributes
        self.inputs = TosaSerializer.toList(inputs)
        self.outputs = TosaSerializer.toList(outputs)
        self.quantInfo = quantInfo

    def __str__(self):
        pieces = ["Op {}\n----\n".format(self.op)]
        pieces.extend(" Input: {}\n".format(name) for name in self.inputs)
        pieces.extend(" Output: {}\n".format(name) for name in self.outputs)
        return "".join(pieces)

    def serialize(self, builder):
        """Serialize this operator and return the finished TosaOperator offset."""
        fb_inputs = TosaSerializer.serializeStrVec(
            builder, self.inputs, TosaOperator.TosaOperatorStartInputsVector
        )
        fb_outputs = TosaSerializer.serializeStrVec(
            builder, self.outputs, TosaOperator.TosaOperatorStartOutputsVector
        )

        has_attributes = self.attributes is not None
        has_quant_info = self.quantInfo is not None

        # Union payloads are separate tables, so they are finished before the
        # operator table itself is started.
        fb_attributes = self.attributes.serialize(builder) if has_attributes else None
        fb_qinfo = self.quantInfo.serialize(builder) if has_quant_info else None

        TosaOperator.TosaOperatorStart(builder)
        TosaOperator.TosaOperatorAddOp(builder, self.op)
        TosaOperator.TosaOperatorAddInputs(builder, fb_inputs)
        TosaOperator.TosaOperatorAddOutputs(builder, fb_outputs)
        if has_attributes:
            TosaOperator.TosaOperatorAddAttributeType(builder, self.attributes.utype)
            TosaOperator.TosaOperatorAddAttribute(builder, fb_attributes)
        if has_quant_info:
            TosaOperator.TosaOperatorAddQuantInfoType(builder, self.quantInfo.utype)
            TosaOperator.TosaOperatorAddQuantInfo(builder, fb_qinfo)

        return TosaOperator.TosaOperatorEnd(builder)
494 return TosaOperator.TosaOperatorEnd(builder)
495
Kevin Cheng550ccc52021-03-03 11:21:43 -0800496
Eric Kunzee5e26762020-10-13 16:11:07 -0700497class TosaSerializerBasicBlock:
498 def __init__(self, name):
499 self.name = name
500 self.operators = []
501
502 # Dict assures uniqueness, but allows us to look up by name
503 self.tensors = dict()
504
505 self.inputs = []
506 self.outputs = []
507
Kevin Cheng550ccc52021-03-03 11:21:43 -0800508 def addTensor(
509 self,
510 name,
511 shape,
512 dtype,
Kevin Cheng82507d72021-06-17 16:01:59 -0700513 data=None,
Kevin Cheng550ccc52021-03-03 11:21:43 -0800514 placeholderFilename=None,
515 ):
Eric Kunzee5e26762020-10-13 16:11:07 -0700516 try:
517 # Someone already added this tensor.
Eric Kunzee5e26762020-10-13 16:11:07 -0700518 tens = self.tensors[name]
Eric Kunzee5e26762020-10-13 16:11:07 -0700519 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800520 self.tensors[name] = TosaSerializerTensor(
Kevin Cheng82507d72021-06-17 16:01:59 -0700521 name, shape, dtype, data, placeholderFilename
Kevin Cheng550ccc52021-03-03 11:21:43 -0800522 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700523
524 return self.tensors[name]
525
526 def addInput(self, name):
527 self.inputs.append(name)
528
529 def addOutput(self, name):
530 self.outputs.append(name)
531
Kevin Cheng550ccc52021-03-03 11:21:43 -0800532 def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None):
533 self.operators.append(
534 TosaSerializerOperator(op, inputs, outputs, attributes, quant_info)
535 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700536
537 def serialize(self, builder):
538 fb_name = builder.CreateString(self.name)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800539 fbv_inputs = TosaSerializer.serializeStrVec(
540 builder, list(self.inputs), TosaBasicBlock.TosaBasicBlockStartInputsVector
541 )
542 fbv_outputs = TosaSerializer.serializeStrVec(
543 builder, list(self.outputs), TosaBasicBlock.TosaBasicBlockStartOutputsVector
544 )
545 fbv_tensors = TosaSerializer.serializeObjVec(
546 builder,
547 list(self.tensors.values()),
548 TosaBasicBlock.TosaBasicBlockStartTensorsVector,
549 )
550 fbv_operators = TosaSerializer.serializeObjVec(
551 builder, self.operators, TosaBasicBlock.TosaBasicBlockStartOperatorsVector
552 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700553
554 TosaBasicBlock.TosaBasicBlockStart(builder)
555 TosaBasicBlock.TosaBasicBlockAddName(builder, fb_name)
556 TosaBasicBlock.TosaBasicBlockAddInputs(builder, fbv_inputs)
557 TosaBasicBlock.TosaBasicBlockAddOutputs(builder, fbv_outputs)
558 TosaBasicBlock.TosaBasicBlockAddTensors(builder, fbv_tensors)
559 TosaBasicBlock.TosaBasicBlockAddOperators(builder, fbv_operators)
560 return TosaBasicBlock.TosaBasicBlockEnd(builder)
561
Kevin Cheng550ccc52021-03-03 11:21:43 -0800562
Eric Kunzee5e26762020-10-13 16:11:07 -0700563@unique
564class TensorDir(IntEnum):
565 PLACEHOLDER = 0
566 CONST = 1
567 INTERMEDIATE = 2
568 RESULT = 3
569
Kevin Cheng550ccc52021-03-03 11:21:43 -0800570
Eric Kunzee5e26762020-10-13 16:11:07 -0700571class TosaSerializer:
572 def __init__(self, pathPrefix):
573
574 # Get the global TOSA version if not already defined
575 try:
576 TOSA_VERSION
577 except NameError:
578 TosaSerializer.setTosaVersion()
579
580 self.builder = flatbuffers.Builder(0)
581
582 self.basicBlocks = []
Kevin Cheng550ccc52021-03-03 11:21:43 -0800583 self.startBasicBlock("main")
Eric Kunzee5e26762020-10-13 16:11:07 -0700584 self.pathPrefix = pathPrefix
585
586 # Indicies used for adding/naming tensors
587 self.currInputIdx = 0
588 self.currConstIdx = 0
589 self.currLayerIdx = 1
590 self.currResultIdx = 0
591
592 # Is this an illegal test that is expected to fail?
Kevin Chengacb550f2021-06-29 15:32:19 -0700593 self.expectedReturnCode = TosaReturnCode.VALID
Kevin Cheng550ccc52021-03-03 11:21:43 -0800594 self.expectedFailureDesc = ""
Eric Kunzee5e26762020-10-13 16:11:07 -0700595
596 def __str__(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800597 str = ""
Eric Kunzee5e26762020-10-13 16:11:07 -0700598 for bb in self.basicBlocks:
599 str = str + bb.__str__()
600 return str
601
Kevin Cheng550ccc52021-03-03 11:21:43 -0800602 def addPlaceholder(self, shape, dtype, vals):
Eric Kunzee5e26762020-10-13 16:11:07 -0700603 if not self.currBasicBlock:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800604 raise Exception("addTensor called without valid basic block")
Eric Kunzee5e26762020-10-13 16:11:07 -0700605
Kevin Cheng550ccc52021-03-03 11:21:43 -0800606 name = "input-{}".format(self.currInputIdx)
607 filename = "{}.npy".format(name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700608 self.currInputIdx = self.currInputIdx + 1
609
Kevin Cheng550ccc52021-03-03 11:21:43 -0800610 tens = self.currBasicBlock.addTensor(name, shape, dtype, None, filename)
Eric Kunzee5e26762020-10-13 16:11:07 -0700611 # This is always an input to the block
612 self.currBasicBlock.addInput(name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700613
614 if vals is not None:
615 np.save(os.path.join(self.pathPrefix, filename), vals, False)
616
617 return tens
618
Kevin Cheng550ccc52021-03-03 11:21:43 -0800619 def addConst(self, shape, dtype, vals):
Eric Kunzee5e26762020-10-13 16:11:07 -0700620 if not self.currBasicBlock:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800621 raise Exception("addTensor called without valid basic block")
Eric Kunzee5e26762020-10-13 16:11:07 -0700622
Kevin Cheng550ccc52021-03-03 11:21:43 -0800623 name = "const-{}".format(self.currInputIdx)
624 filename = "{}.npy".format(name)
Eric Kunzee5e26762020-10-13 16:11:07 -0700625 self.currInputIdx = self.currInputIdx + 1
626
Kevin Cheng82507d72021-06-17 16:01:59 -0700627 tens = self.currBasicBlock.addTensor(name, shape, dtype, vals)
Eric Kunzee5e26762020-10-13 16:11:07 -0700628 # Add the operator now
629 self.currBasicBlock.addOperator(tosa.Op.Op().CONST, [], name)
630
Eric Kunzee5e26762020-10-13 16:11:07 -0700631 return tens
632
Kevin Cheng550ccc52021-03-03 11:21:43 -0800633 def addIntermediate(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -0700634
635 if not self.currBasicBlock:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800636 raise Exception("addTensor called without valid basic block")
Eric Kunzee5e26762020-10-13 16:11:07 -0700637
Kevin Cheng550ccc52021-03-03 11:21:43 -0800638 name = "layer-{}".format(self.currLayerIdx)
Eric Kunzee5e26762020-10-13 16:11:07 -0700639 self.currLayerIdx = self.currLayerIdx + 1
640
Kevin Cheng82507d72021-06-17 16:01:59 -0700641 tens = self.currBasicBlock.addTensor(name, shape, dtype, None)
Eric Kunzee5e26762020-10-13 16:11:07 -0700642
643 return tens
644
645 def addInputTensor(self, tensor):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800646 self.currBasicBlock.addTensor(tensor.name, tensor.shape, tensor.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -0700647 self.currBasicBlock.addInput(tensor.name)
648
649 def addOutputTensor(self, tensor):
650 self.currBasicBlock.addOutput(tensor.name)
651
Kevin Cheng550ccc52021-03-03 11:21:43 -0800652 def addOutput(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -0700653 if not self.currBasicBlock:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800654 raise Exception("addTensor called without valid basic block")
Eric Kunzee5e26762020-10-13 16:11:07 -0700655
Kevin Cheng550ccc52021-03-03 11:21:43 -0800656 name = "result-{}".format(self.currResultIdx)
Eric Kunzee5e26762020-10-13 16:11:07 -0700657 self.currResultIdx = self.currResultIdx + 1
658
Kevin Cheng550ccc52021-03-03 11:21:43 -0800659 tens = self.currBasicBlock.addTensor(name, shape, dtype, None)
Eric Kunzee5e26762020-10-13 16:11:07 -0700660 self.currBasicBlock.addOutput(name)
661 return tens
662
Kevin Cheng550ccc52021-03-03 11:21:43 -0800663 def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None):
Eric Kunzee5e26762020-10-13 16:11:07 -0700664
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700665 if op == tosa.Op.Op().CONST:
666 raise Exception("Use addConstTensor() to add CONST ops")
Eric Kunzee5e26762020-10-13 16:11:07 -0700667
Kevin Cheng550ccc52021-03-03 11:21:43 -0800668 return self.currBasicBlock.addOperator(
669 op, inputs, outputs, attributes, quant_info
670 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700671
Kevin Chengacb550f2021-06-29 15:32:19 -0700672 def setExpectedReturnCode(self, val, desc=""):
Eric Kunzee5e26762020-10-13 16:11:07 -0700673
Kevin Chengacb550f2021-06-29 15:32:19 -0700674 self.expectedReturnCode = val
Eric Kunzee5e26762020-10-13 16:11:07 -0700675 self.expectedFailureDesc = desc
676
677 def serialize(self):
678
679 builder = self.builder
680
681 Version.VersionStart(builder)
682 Version.VersionAdd_major(builder, TOSA_VERSION[0])
683 Version.VersionAdd_minor(builder, TOSA_VERSION[1])
684 Version.VersionAdd_patch(builder, TOSA_VERSION[2])
685 Version.VersionAdd_experimental(builder, TOSA_VERSION[3])
686 version = Version.VersionEnd(builder)
687
Kevin Cheng550ccc52021-03-03 11:21:43 -0800688 fbv_bb = TosaSerializer.serializeObjVec(
689 builder, self.basicBlocks, TosaGraph.TosaGraphStartBlocksVector
690 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700691
692 TosaGraph.TosaGraphStart(builder)
693 TosaGraph.TosaGraphAddVersion(builder, version)
694 TosaGraph.TosaGraphAddBlocks(builder, fbv_bb)
695 graph = TosaGraph.TosaGraphEnd(builder)
696
697 self.builder.Finish(graph)
698 return self.builder.Output()
699
700 def writeJson(self, tosa_filename):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800701 """Write a json test file so that it is fairly easy to pick up the test
702 and generate commands for third party tool"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700703 test_desc = dict()
704
Kevin Cheng550ccc52021-03-03 11:21:43 -0800705 test_desc["tosa_file"] = tosa_filename
Eric Kunzee5e26762020-10-13 16:11:07 -0700706 ifm_name = []
Eric Kunzee5e26762020-10-13 16:11:07 -0700707 ifm_file = []
708 ofm_name = []
709 ofm_file = []
Eric Kunzee5e26762020-10-13 16:11:07 -0700710
711 for b in self.basicBlocks:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800712 if b.name == "main":
Eric Kunzee5e26762020-10-13 16:11:07 -0700713 for i in b.inputs:
714 ifm_name.append(i)
Eric Kunzee5e26762020-10-13 16:11:07 -0700715 ifm_file.append(b.tensors[i].placeholderFilename)
716 for o in b.outputs:
717 ofm_name.append(o)
Eric Kunzee5e26762020-10-13 16:11:07 -0700718 # Make up an OFM filename here. One isn't generated until the reference tool is
719 # run, so any name is a good name
Kevin Cheng550ccc52021-03-03 11:21:43 -0800720 ofm_file.append("ref-{}.npy".format(o))
Eric Kunzee5e26762020-10-13 16:11:07 -0700721
Kevin Chengcd79f0e2021-06-03 15:00:34 -0700722 test_desc["ifm_name"] = ifm_name
Kevin Cheng550ccc52021-03-03 11:21:43 -0800723 test_desc["ifm_file"] = ifm_file
Kevin Cheng550ccc52021-03-03 11:21:43 -0800724 test_desc["ofm_name"] = ofm_name
Kevin Cheng550ccc52021-03-03 11:21:43 -0800725 test_desc["ofm_file"] = ofm_file
Kevin Chengacb550f2021-06-29 15:32:19 -0700726 test_desc["expected_return_code"] = self.expectedReturnCode
Eric Kunzee5e26762020-10-13 16:11:07 -0700727 if self.expectedFailureDesc:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800728 test_desc["expected_failure_desc"] = self.expectedFailureDesc
Eric Kunzee5e26762020-10-13 16:11:07 -0700729
Kevin Cheng550ccc52021-03-03 11:21:43 -0800730 return json.dumps(test_desc, indent=" ")
Eric Kunzee5e26762020-10-13 16:11:07 -0700731
732 def startBasicBlock(self, name):
733 self.currBasicBlock = TosaSerializerBasicBlock(name)
734 self.basicBlocks.append(self.currBasicBlock)
735
736 @staticmethod
737 def serializeStrVec(builder, vec, start_fcn):
738 fb_strs = [builder.CreateString(i) for i in vec]
739 start_fcn(builder, len(fb_strs))
740 for s in fb_strs[::-1]:
741 builder.PrependUOffsetTRelative(s)
742 return builder.EndVector(len(fb_strs))
743
744 @staticmethod
Kevin Cheng82507d72021-06-17 16:01:59 -0700745 def serializeUint8Vec(builder, vec):
746 builder.StartVector(1, len(vec), 8)
747 for v in vec[::-1]:
748 builder.PrependUint8(v)
749 return builder.EndVector(len(vec))
750
751 @staticmethod
Eric Kunzee5e26762020-10-13 16:11:07 -0700752 def serializeInt32Vec(builder, vec):
753 builder.StartVector(4, len(vec), 4)
754 for v in vec[::-1]:
755 builder.PrependInt32(v)
756 return builder.EndVector(len(vec))
757
758 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -0800759 def serializeFpVec(builder, vec):
760 builder.StartVector(4, len(vec), 4)
761 for v in vec[::-1]:
762 builder.PrependFloat32(v)
763 return builder.EndVector(len(vec))
764
765 @staticmethod
Eric Kunzee5e26762020-10-13 16:11:07 -0700766 def serializeObjVec(builder, vec, start_fcn):
767 serialized_vec = []
768 for v in vec[::-1]:
769 serialized_vec.append(v.serialize(builder))
770
771 start_fcn(builder, len(vec))
772 for v in serialized_vec:
773 builder.PrependUOffsetTRelative(v)
774 return builder.EndVector(len(vec))
775
776 @staticmethod
777 def toList(val):
778 if isinstance(val, list):
779 return val
780 else:
781 return [val]
782
783 @staticmethod
784 def setTosaVersion():
785 # Create a dummy flatbuffers file with the default version information
786 # There does not appear to be a better way to get a constant from a
787 # flatbuffer schema file
788 builder = flatbuffers.Builder(0)
789 Version.VersionStart(builder)
790 ver = Version.VersionEnd(builder)
791 TosaGraph.TosaGraphStart(builder)
792 TosaGraph.TosaGraphAddVersion(builder, ver)
793 gr = TosaGraph.TosaGraphEnd(builder)
794 builder.Finish(gr)
795
796 out = builder.Output()
797
798 gr = TosaGraph.TosaGraph()
799 root = gr.GetRootAsTosaGraph(out, 0)
800
801 # Store the version as a global variable so that it only needs to be
802 # generated once per process.
803 global TOSA_VERSION
Kevin Cheng550ccc52021-03-03 11:21:43 -0800804 TOSA_VERSION = [
805 root.Version()._major(),
806 root.Version()._minor(),
807 root.Version()._patch(),
808 root.Version()._experimental(),
809 ]