blob: b11f9cd993df37e9eb0d558c7efc04aa3fae08d2 [file] [log] [blame]
Kevin Chengfea5a372021-10-11 18:38:47 +00001# Copyright (c) 2020-2021, ARM Limited.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15#!/usr/bin/env python3
16
17import os
18import sys
19import json
20import flatbuffers
21import numpy as np
22import struct
23from enum import Enum, IntEnum, unique
24from tosa import (
25 TosaGraph,
26 TosaBasicBlock,
27 TosaTensor,
28 TosaOperator,
29 DType,
30 Op,
31 ResizeMode,
32 Version,
33)
34from tosa_ref_run import TosaReturnCode
35
36import tosa
37
# These numbers must stay in sync with the default version values declared
# in schema/tosa.fbs.
TOSA_VERSION_MAJOR, TOSA_VERSION_MINOR, TOSA_VERSION_PATCH = 0, 24, 0
TOSA_VERSION_DRAFT = True

# Positional layout [major, minor, patch, draft]; TosaSerializer.serialize()
# reads the fields back out by index.
TOSA_VERSION = [
    TOSA_VERSION_MAJOR,
    TOSA_VERSION_MINOR,
    TOSA_VERSION_PATCH,
    TOSA_VERSION_DRAFT,
]
# The Python types generated by flatc offer no programmatic way to map an
# integer DType value back to its name, so the table below is maintained by
# hand. Entry i is the name of the DType whose value is i; it must follow
# the enum order in schema/tosa.fbs.
DType = tosa.DType.DType()
DTypeNames = (
    "UNKNOWN BOOL UINT8 INT4 INT8 "
    "INT16 INT32 INT48 FLOAT"
).split()
62
# Four-byte flatbuffers file identifier: the ASCII text "TOSA". It must be
# kept in sync with the file identifier declared in schema/tosa.fbs.
TOSA_GRAPH_IDENTIFIER = b"TOSA"

# numpy uint64 mask that selects the lowest byte of a value.
ByteMask = np.uint64(255)
67
68
def dtype_str_to_val(name):
    """Return the DType value (index into DTypeNames) whose name matches *name*.

    The comparison is case-insensitive (``str.casefold``), so "int8", "INT8"
    and "Int8" all map to the same value.

    Raises:
        ValueError: if *name* matches no entry of DTypeNames. ValueError is a
            subclass of Exception, so existing ``except Exception`` handlers
            keep working.
    """
    wanted = name.casefold()
    for value, candidate in enumerate(DTypeNames):
        if candidate.casefold() == wanted:
            return value
    raise ValueError("Unable to parse DType name {}".format(name))
75
76
class TosaSerializerUnion:
    """Collects the fields of one flatbuffers union table and serializes them.

    A subclass builder method sets ``optFcns`` (the table's start/end
    functions) and ``utype`` (the union discriminator), then queues
    ``(add_function, value)`` pairs on the typed field lists.
    """

    def __init__(self):
        # (start_function, end_function) of the table; set by a subclass method.
        self.optFcns = None

        # Value from the union type enumeration; set by a subclass method.
        self.utype = None

        # Queued fields, grouped by value kind. Every entry is an
        # (add_function, value) pair.
        self.ints = []
        self.bools = []
        self.floats = []
        self.strings = []
        self.intvecs = []
        self.fpvecs = []

    def serialize(self, builder):
        """Write the queued fields as one table into *builder*.

        Returns:
            Whatever the table's end function returns (the table offset).
        """
        # Strings and vectors have to exist in the buffer before the table is
        # started, so build them all up front.
        built_strings = [
            (add, builder.CreateString(text)) for add, text in self.strings
        ]
        built_intvecs = [
            (add, TosaSerializer.serializeInt32Vec(builder, vec))
            for add, vec in self.intvecs
        ]
        built_fpvecs = [
            (add, TosaSerializer.serializeFpVec(builder, vec))
            for add, vec in self.fpvecs
        ]

        start_table, end_table = self.optFcns

        start_table(builder)
        # Scalars first (ints, bools, floats), then the offsets of the
        # strings, int vectors and float vectors built above.
        for add, value in (
            *self.ints,
            *self.bools,
            *self.floats,
            *built_strings,
            *built_intvecs,
            *built_fpvecs,
        ):
            add(builder, value)
        return end_table(builder)
137
138
class TosaSerializerAttribute(TosaSerializerUnion):
    """Builds one table of the TOSA ``Attribute`` union.

    Call one of the builder methods below on a fresh instance. Each method sets
    ``utype``/``optFcns`` and queues its fields on the typed lists inherited
    from TosaSerializerUnion; ``serialize()`` then writes the table.
    """

    def __init__(self):
        super().__init__()

    def PoolAttribute(self, kernel, stride, padding):
        """Configure a PoolAttribute; all three arguments are int32 vectors."""
        from tosa import PoolAttribute as a, Attribute

        self.utype = Attribute.Attribute().PoolAttribute

        self.optFcns = (a.PoolAttributeStart, a.PoolAttributeEnd)
        self.intvecs.append((a.PoolAttributeAddPadding, padding))
        self.intvecs.append((a.PoolAttributeAddKernel, kernel))
        self.intvecs.append((a.PoolAttributeAddStride, stride))

    def ConvAttribute(self, padding, stride, dilation):
        """Configure a ConvAttribute; all three arguments are int32 vectors."""
        from tosa import ConvAttribute as a, Attribute

        self.utype = Attribute.Attribute().ConvAttribute
        self.optFcns = (a.ConvAttributeStart, a.ConvAttributeEnd)

        self.intvecs.append((a.ConvAttributeAddPadding, padding))
        self.intvecs.append((a.ConvAttributeAddStride, stride))
        self.intvecs.append((a.ConvAttributeAddDilation, dilation))

    def TransposeConvAttribute(self, outpad, stride, dilation, output_shape):
        """Configure a TransposeConvAttribute; all arguments are int32 vectors."""
        from tosa import TransposeConvAttribute as a, Attribute

        self.utype = Attribute.Attribute().TransposeConvAttribute
        self.optFcns = (a.TransposeConvAttributeStart, a.TransposeConvAttributeEnd)

        self.intvecs.append((a.TransposeConvAttributeAddOutpad, outpad))
        self.intvecs.append((a.TransposeConvAttributeAddStride, stride))
        self.intvecs.append((a.TransposeConvAttributeAddDilation, dilation))
        self.intvecs.append((a.TransposeConvAttributeAddOutputShape, output_shape))

    def PadAttribute(self, padding, pad_const_int, pad_const_fp):
        """Configure a PadAttribute: padding vector plus int and float pad constants."""
        from tosa import PadAttribute as a, Attribute

        self.utype = Attribute.Attribute().PadAttribute
        self.optFcns = (a.PadAttributeStart, a.PadAttributeEnd)

        self.intvecs.append((a.PadAttributeAddPadding, padding))
        self.ints.append((a.PadAttributeAddPadConstInt, pad_const_int))
        self.floats.append((a.PadAttributeAddPadConstFp, pad_const_fp))

    def AxisAttribute(self, axis):
        """Configure an AxisAttribute holding a single integer axis."""
        from tosa import AxisAttribute as a, Attribute

        self.utype = Attribute.Attribute().AxisAttribute
        self.optFcns = (a.AxisAttributeStart, a.AxisAttributeEnd)

        self.ints.append((a.AxisAttributeAddAxis, axis))

    def ReshapeAttribute(self, shape):
        """Configure a ReshapeAttribute with the int32 vector *shape*."""
        from tosa import ReshapeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ReshapeAttribute
        self.optFcns = (a.ReshapeAttributeStart, a.ReshapeAttributeEnd)

        self.intvecs.append((a.ReshapeAttributeAddShape, shape))

    def SliceAttribute(self, begin, size):
        """Configure a SliceAttribute with int32 vectors *begin* and *size*."""
        from tosa import SliceAttribute as a, Attribute

        self.utype = Attribute.Attribute().SliceAttribute
        self.optFcns = (a.SliceAttributeStart, a.SliceAttributeEnd)

        self.intvecs.append((a.SliceAttributeAddBegin, begin))
        self.intvecs.append((a.SliceAttributeAddSize, size))

    def TileAttribute(self, multiples):
        """Configure a TileAttribute with the int32 vector *multiples*."""
        from tosa import TileAttribute as a, Attribute

        self.utype = Attribute.Attribute().TileAttribute
        self.optFcns = (a.TileAttributeStart, a.TileAttributeEnd)

        self.intvecs.append((a.TileAttributeAddMultiples, multiples))

    def ResizeAttribute(
        self, output_size, stride, offset, shift, stride_fp, offset_fp, mode
    ):
        """Configure a ResizeAttribute.

        output_size, stride and offset are int32 vectors; shift is an int;
        stride_fp and offset_fp are float vectors; mode is stored as an int
        (presumably a ResizeMode value -- confirm with callers).
        """
        from tosa import ResizeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ResizeAttribute
        self.optFcns = (a.ResizeAttributeStart, a.ResizeAttributeEnd)

        self.intvecs.append((a.ResizeAttributeAddOutputSize, output_size))
        self.intvecs.append((a.ResizeAttributeAddStride, stride))
        self.intvecs.append((a.ResizeAttributeAddOffset, offset))
        self.ints.append((a.ResizeAttributeAddShift, shift))
        self.fpvecs.append((a.ResizeAttributeAddStrideFp, stride_fp))
        self.fpvecs.append((a.ResizeAttributeAddOffsetFp, offset_fp))
        self.ints.append((a.ResizeAttributeAddMode, mode))

    def ClampAttribute(self, minint, maxint, minfp, maxfp):
        """Configure a ClampAttribute with integer and floating-point bounds."""
        from tosa import ClampAttribute as a, Attribute

        self.utype = Attribute.Attribute().ClampAttribute
        self.optFcns = (a.ClampAttributeStart, a.ClampAttributeEnd)

        self.ints.append((a.ClampAttributeAddMinInt, minint))
        self.ints.append((a.ClampAttributeAddMaxInt, maxint))

        # The floating-point bounds are float fields, so queue them with the
        # other floats rather than in the integer list.
        self.floats.append((a.ClampAttributeAddMinFp, minfp))
        self.floats.append((a.ClampAttributeAddMaxFp, maxfp))

    def RescaleAttribute(
        self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel
    ):
        """Configure a RescaleAttribute.

        input_zp/output_zp are ints; multiplier/shift are int32 vectors;
        scale32, double_round and per_channel are bools.
        """
        from tosa import RescaleAttribute as a, Attribute

        self.utype = Attribute.Attribute().RescaleAttribute
        self.optFcns = (a.RescaleAttributeStart, a.RescaleAttributeEnd)

        self.ints.append((a.RescaleAttributeAddInputZp, input_zp))
        self.ints.append((a.RescaleAttributeAddOutputZp, output_zp))
        self.intvecs.append((a.RescaleAttributeAddMultiplier, multiplier))
        self.intvecs.append((a.RescaleAttributeAddShift, shift))
        self.bools.append((a.RescaleAttributeAddScale32, scale32))
        self.bools.append((a.RescaleAttributeAddDoubleRound, double_round))
        self.bools.append((a.RescaleAttributeAddPerChannel, per_channel))

    def MulAttribute(self, shift):
        """Configure a MulAttribute holding the integer *shift*."""
        from tosa import MulAttribute as a, Attribute

        self.utype = Attribute.Attribute().MulAttribute
        self.optFcns = (a.MulAttributeStart, a.MulAttributeEnd)

        self.ints.append((a.MulAttributeAddShift, shift))

    def ArithmeticRightShiftAttribute(self, round):
        """Configure an ArithmeticRightShiftAttribute with the bool *round*."""
        from tosa import ArithmeticRightShiftAttribute as a, Attribute

        self.utype = Attribute.Attribute().ArithmeticRightShiftAttribute
        self.optFcns = (
            a.ArithmeticRightShiftAttributeStart,
            a.ArithmeticRightShiftAttributeEnd,
        )

        self.bools.append((a.ArithmeticRightShiftAttributeAddRound, round))

    def CondIfAttribute(self, then_branch, else_branch):
        """Configure a CondIfAttribute naming the then/else basic blocks."""
        from tosa import CondIfAttribute as a, Attribute

        self.utype = Attribute.Attribute().CondIfAttribute
        self.optFcns = (a.CondIfAttributeStart, a.CondIfAttributeEnd)

        self.strings.append((a.CondIfAttributeAddThenBranch, then_branch))
        self.strings.append((a.CondIfAttributeAddElseBranch, else_branch))

    def WhileLoopAttribute(self, cond_branch, body_branch):
        """Configure a WhileLoopAttribute naming the cond/body basic blocks."""
        from tosa import WhileLoopAttribute as a, Attribute

        self.utype = Attribute.Attribute().WhileLoopAttribute
        self.optFcns = (a.WhileLoopAttributeStart, a.WhileLoopAttributeEnd)

        self.strings.append((a.WhileLoopAttributeAddCondBranch, cond_branch))
        self.strings.append((a.WhileLoopAttributeAddBodyBranch, body_branch))

    def TransposeAttribute(self, perm):
        """Configure a TransposeAttribute with the int32 vector *perm*."""
        from tosa import TransposeAttribute as a, Attribute

        self.utype = Attribute.Attribute().TransposeAttribute
        self.optFcns = (a.TransposeAttributeStart, a.TransposeAttributeEnd)

        self.intvecs.append((a.TransposeAttributeAddPerm, perm))

    def TableAttribute(self, table):
        """Configure a TableAttribute with the int32 vector *table*."""
        from tosa import TableAttribute as a, Attribute

        self.utype = Attribute.Attribute().TableAttribute
        self.optFcns = (a.TableAttributeStart, a.TableAttributeEnd)

        self.intvecs.append((a.TableAttributeAddTable, table))
Kevin Chengfea5a372021-10-11 18:38:47 +0000315
class TosaSerializerQuantInfo(TosaSerializerUnion):
    """Builds one table of the TOSA ``QuantInfo`` union.

    Each builder method sets ``utype``/``optFcns`` and queues its zero-point
    fields as integers.
    """

    def __init__(self):
        super().__init__()

    def ConvQuantInfo(self, input_zp, weight_zp):
        """Configure a ConvQuantInfo with input and weight zero points."""
        from tosa import ConvQuantInfo as table, QuantInfo

        self.utype = QuantInfo.QuantInfo().ConvQuantInfo
        self.optFcns = (table.ConvQuantInfoStart, table.ConvQuantInfoEnd)
        self.ints.extend(
            [
                (table.ConvQuantInfoAddInputZp, input_zp),
                (table.ConvQuantInfoAddWeightZp, weight_zp),
            ]
        )

    def UnaryQuantInfo(self, input_zp, output_zp):
        """Configure a UnaryQuantInfo with input and output zero points."""
        from tosa import UnaryQuantInfo as table, QuantInfo

        self.utype = QuantInfo.QuantInfo().UnaryQuantInfo
        self.optFcns = (table.UnaryQuantInfoStart, table.UnaryQuantInfoEnd)
        self.ints.extend(
            [
                (table.UnaryQuantInfoAddInputZp, input_zp),
                (table.UnaryQuantInfoAddOutputZp, output_zp),
            ]
        )

    def MatMulQuantInfo(self, a_zp, b_zp):
        """Configure a MatMulQuantInfo with the zero points of both operands."""
        from tosa import MatMulQuantInfo as table, QuantInfo

        self.utype = QuantInfo.QuantInfo().MatMulQuantInfo
        self.optFcns = (table.MatMulQuantInfoStart, table.MatMulQuantInfoEnd)
        self.ints.extend(
            [
                (table.MatMulQuantInfoAddAZp, a_zp),
                (table.MatMulQuantInfoAddBZp, b_zp),
            ]
        )

    def PadQuantInfo(self, input_zp):
        """Configure a PadQuantInfo with the input zero point."""
        from tosa import PadQuantInfo as table, QuantInfo

        self.utype = QuantInfo.QuantInfo().PadQuantInfo
        self.optFcns = (table.PadQuantInfoStart, table.PadQuantInfoEnd)
        self.ints.extend([(table.PadQuantInfoAddInputZp, input_zp)])
352
353
class TosaSerializerTensor:
    """One tensor of a basic block: name, shape, dtype and optional constant data."""

    def __init__(
        self,
        name,
        shape,
        dtype,
        data=None,
        placeholderFilename=None,
    ):
        self.name = name

        if isinstance(shape, np.ndarray):
            shape = shape.astype(int).tolist()
        self.shape = list(map(int, shape))
        self.dtype = dtype

        # Constant payload, flattened into a Python list (or None). FLOAT
        # tensors keep float values: coercing them with int(), as every other
        # dtype is, would silently truncate them (1.5 -> 1) before serialize()
        # packs them as float32.
        if dtype == DType.FLOAT:
            if isinstance(data, np.ndarray):
                data = data.flatten().tolist()
            convert = float
        else:
            if isinstance(data, np.ndarray):
                data = data.flatten().astype(int).tolist()
            convert = int

        if isinstance(data, list):
            self.data = list(map(convert, data))
        else:
            self.data = None

        # Filename for placeholder tensors. Test generation writes these to
        # disk, but the network treats them as input tensors, so their data
        # does not appear in the TOSA serialization. A unit test built around
        # these inputs can get the filename from here.
        self.placeholderFilename = placeholderFilename

    def __str__(self):
        return "TosaSerializerTensor name: {} shape: {} dtype: {}".format(
            self.name,
            self.shape,
            DTypeNames[self.dtype],
        )

    def setDtype(self, dtype):
        """Change the dtype. NOTE: self.data is not re-converted."""
        self.dtype = dtype

    def _dataBytes(self):
        """Encode self.data as a little-endian list of byte values (0..255).

        Raises:
            Exception: if self.dtype has no encoding here (e.g. UINT8).
        """
        dtype = self.dtype
        data = self.data

        if dtype == DType.INT4:
            # Two 4-bit values per byte: element 2*i goes in the low nibble and
            # element 2*i+1 in the high nibble. An odd trailing element is
            # paired with 0.
            packed = []
            for i in range(0, len(data), 2):
                low = data[i]
                high = data[i + 1] if i + 1 < len(data) else 0
                packed.append((low & 0xF) | ((high & 0xF) << 4))
            return packed

        if dtype == DType.FLOAT:
            out = []
            for val in data:
                # "<f": IEEE-754 float32, little-endian.
                out.extend(struct.pack("<f", val))
            return out

        widths = {
            DType.BOOL: 1,
            DType.INT8: 1,
            DType.INT16: 2,
            DType.INT32: 4,
            DType.INT48: 6,
        }
        width = widths.get(dtype)
        if width is None:
            raise Exception(
                "unsupported data type {}".format(DTypeNames[self.dtype])
            )

        # Masking a Python int gives two's-complement wrap-around for negative
        # values. np.uintN(negative) raises OverflowError on NumPy >= 2.
        mask = (1 << (8 * width)) - 1
        out = []
        for val in data:
            out.extend((int(val) & mask).to_bytes(width, "little"))
        return out

    def serialize(self, builder):
        """Write this tensor as a TosaTensor table and return its offset.

        The data vector is written only when self.data is non-empty.
        """
        fb_name = builder.CreateString(self.name)
        fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape)
        if self.data:
            fb_data = TosaSerializer.serializeUint8Vec(builder, self._dataBytes())

        TosaTensor.TosaTensorStart(builder)
        TosaTensor.TosaTensorAddName(builder, fb_name)
        TosaTensor.TosaTensorAddShape(builder, fb_shapes)
        TosaTensor.TosaTensorAddType(builder, self.dtype)
        if self.data:
            TosaTensor.TosaTensorAddData(builder, fb_data)

        return TosaTensor.TosaTensorEnd(builder)
467
468
class TosaSerializerOperator:
    """One TOSA operator: opcode, input/output tensor names and optional tables."""

    def __init__(self, op, inputs, outputs, attributes=None, quantInfo=None):
        self.op = op
        self.attributes = attributes
        # A single tensor name is wrapped into a one-element list.
        self.inputs = TosaSerializer.toList(inputs)
        self.outputs = TosaSerializer.toList(outputs)
        self.quantInfo = quantInfo

    def __str__(self):
        pieces = ["Op {}\n----\n".format(self.op)]
        pieces.extend(" Input: {}\n".format(name) for name in self.inputs)
        pieces.extend(" Output: {}\n".format(name) for name in self.outputs)
        return "".join(pieces)

    def serialize(self, builder):
        """Write this operator as a TosaOperator table and return its offset."""
        fb_inputs = TosaSerializer.serializeStrVec(
            builder, self.inputs, TosaOperator.TosaOperatorStartInputsVector
        )
        fb_outputs = TosaSerializer.serializeStrVec(
            builder, self.outputs, TosaOperator.TosaOperatorStartOutputsVector
        )

        has_attributes = self.attributes is not None
        has_quant_info = self.quantInfo is not None

        # The nested union tables must be written before TosaOperatorStart.
        fb_attributes = self.attributes.serialize(builder) if has_attributes else None
        fb_qinfo = self.quantInfo.serialize(builder) if has_quant_info else None

        TosaOperator.TosaOperatorStart(builder)
        TosaOperator.TosaOperatorAddOp(builder, self.op)
        TosaOperator.TosaOperatorAddInputs(builder, fb_inputs)
        TosaOperator.TosaOperatorAddOutputs(builder, fb_outputs)
        if has_attributes:
            TosaOperator.TosaOperatorAddAttributeType(builder, self.attributes.utype)
            TosaOperator.TosaOperatorAddAttribute(builder, fb_attributes)
        if has_quant_info:
            TosaOperator.TosaOperatorAddQuantInfoType(builder, self.quantInfo.utype)
            TosaOperator.TosaOperatorAddQuantInfo(builder, fb_qinfo)

        return TosaOperator.TosaOperatorEnd(builder)
513
514
class TosaSerializerBasicBlock:
    """A named basic block: its tensors, operators and input/output tensor names."""

    def __init__(self, name):
        self.name = name
        self.operators = []

        # Keyed by tensor name: guarantees uniqueness and allows lookup by name.
        self.tensors = dict()

        self.inputs = []
        self.outputs = []

    def __str__(self):
        """Readable dump of the block name, I/O names, tensors and operators."""
        lines = ["BasicBlock {}".format(self.name)]
        lines.extend(" Input: {}".format(name) for name in self.inputs)
        lines.extend(" Output: {}".format(name) for name in self.outputs)
        lines.extend(" {}".format(tensor) for tensor in self.tensors.values())
        text = "\n".join(lines) + "\n"
        return text + "".join(str(op) for op in self.operators)

    def addTensor(
        self,
        name,
        shape,
        dtype,
        data=None,
        placeholderFilename=None,
    ):
        """Register tensor *name* and return it.

        If a tensor with this name already exists, the existing object is
        returned unchanged and the remaining arguments are ignored.
        """
        if name not in self.tensors:
            self.tensors[name] = TosaSerializerTensor(
                name, shape, dtype, data, placeholderFilename
            )

        return self.tensors[name]

    def addInput(self, name):
        self.inputs.append(name)

    def addOutput(self, name):
        self.outputs.append(name)

    def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None):
        self.operators.append(
            TosaSerializerOperator(op, inputs, outputs, attributes, quant_info)
        )

    def serialize(self, builder):
        """Write this block as a TosaBasicBlock table and return its offset."""
        fb_name = builder.CreateString(self.name)
        fbv_inputs = TosaSerializer.serializeStrVec(
            builder, list(self.inputs), TosaBasicBlock.TosaBasicBlockStartInputsVector
        )
        fbv_outputs = TosaSerializer.serializeStrVec(
            builder, list(self.outputs), TosaBasicBlock.TosaBasicBlockStartOutputsVector
        )
        fbv_tensors = TosaSerializer.serializeObjVec(
            builder,
            list(self.tensors.values()),
            TosaBasicBlock.TosaBasicBlockStartTensorsVector,
        )
        fbv_operators = TosaSerializer.serializeObjVec(
            builder, self.operators, TosaBasicBlock.TosaBasicBlockStartOperatorsVector
        )

        TosaBasicBlock.TosaBasicBlockStart(builder)
        TosaBasicBlock.TosaBasicBlockAddName(builder, fb_name)
        TosaBasicBlock.TosaBasicBlockAddInputs(builder, fbv_inputs)
        TosaBasicBlock.TosaBasicBlockAddOutputs(builder, fbv_outputs)
        TosaBasicBlock.TosaBasicBlockAddTensors(builder, fbv_tensors)
        TosaBasicBlock.TosaBasicBlockAddOperators(builder, fbv_operators)
        return TosaBasicBlock.TosaBasicBlockEnd(builder)
579
580
@unique
class TensorDir(IntEnum):
    """Category of a tensor, as an int-valued enumeration (values 0..3).

    NOTE(review): not referenced elsewhere in this file. The members appear to
    mirror TosaSerializer.addPlaceholder/addConst/addIntermediate/addOutput.
    """

    # Assigned in declaration order: 0, 1, 2, 3.
    PLACEHOLDER, CONST, INTERMEDIATE, RESULT = range(4)
587
588
class TosaSerializer:
    """Builds a serialized TOSA graph and a JSON description of the test.

    The constructor starts a basic block named "main". Tensors and operators
    are added to the current basic block. serialize() returns the flatbuffer
    bytes and writeJson() returns the JSON test descriptor.
    """

    def __init__(self, pathPrefix):
        self.builder = flatbuffers.Builder(0)

        self.basicBlocks = []
        self.startBasicBlock("main")

        # Directory that addPlaceholder() saves its .npy files into.
        self.pathPrefix = pathPrefix

        # Indices used for naming generated tensors.
        self.currInputIdx = 0
        # NOTE(review): currConstIdx is never read in this class; addConst()
        # numbers constants from currInputIdx. It is kept for external readers.
        self.currConstIdx = 0
        self.currLayerIdx = 1
        self.currResultIdx = 0

        # Is this an illegal test that is expected to fail?
        self.expectedReturnCode = TosaReturnCode.VALID
        self.expectedFailure = False
        self.expectedFailureDesc = ""

    def __str__(self):
        return "".join(str(bb) for bb in self.basicBlocks)

    def addPlaceholder(self, shape, dtype, vals):
        """Add a block input tensor named "input-N" and return it.

        The tensor records "input-N.npy" as its placeholder file. If *vals* is
        not None, it is saved to that file under pathPrefix (without pickling).
        """
        if not self.currBasicBlock:
            raise Exception("addTensor called without valid basic block")

        name = "input-{}".format(self.currInputIdx)
        filename = "{}.npy".format(name)
        self.currInputIdx = self.currInputIdx + 1

        tens = self.currBasicBlock.addTensor(name, shape, dtype, None, filename)
        # This is always an input to the block.
        self.currBasicBlock.addInput(name)

        if vals is not None:
            np.save(os.path.join(self.pathPrefix, filename), vals, False)

        return tens

    def addConst(self, shape, dtype, vals):
        """Add a constant tensor "const-N" holding *vals*, plus its CONST op.

        N comes from the same counter that addPlaceholder() uses.
        """
        if not self.currBasicBlock:
            raise Exception("addTensor called without valid basic block")

        name = "const-{}".format(self.currInputIdx)
        self.currInputIdx = self.currInputIdx + 1

        tens = self.currBasicBlock.addTensor(name, shape, dtype, vals)
        # Add the CONST operator directly; addOperator() rejects CONST.
        self.currBasicBlock.addOperator(tosa.Op.Op().CONST, [], name)

        return tens

    def addIntermediate(self, shape, dtype):
        """Add an intermediate tensor named "layer-N" (N starts at 1) and return it."""
        if not self.currBasicBlock:
            raise Exception("addTensor called without valid basic block")

        name = "layer-{}".format(self.currLayerIdx)
        self.currLayerIdx = self.currLayerIdx + 1

        return self.currBasicBlock.addTensor(name, shape, dtype, None)

    def addInputTensor(self, tensor):
        """Register an existing tensor object as an input of the current block."""
        self.currBasicBlock.addTensor(tensor.name, tensor.shape, tensor.dtype)
        self.currBasicBlock.addInput(tensor.name)

    def addOutputTensor(self, tensor):
        """Mark an existing tensor's name as an output of the current block."""
        self.currBasicBlock.addOutput(tensor.name)

    def addOutput(self, shape, dtype):
        """Add a block output tensor named "result-N" and return it."""
        if not self.currBasicBlock:
            raise Exception("addTensor called without valid basic block")

        name = "result-{}".format(self.currResultIdx)
        self.currResultIdx = self.currResultIdx + 1

        tens = self.currBasicBlock.addTensor(name, shape, dtype, None)
        self.currBasicBlock.addOutput(name)
        return tens

    def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None):
        """Append a non-CONST operator to the current block.

        Raises:
            Exception: if *op* is CONST; use addConst() for constants.
        """
        if op == tosa.Op.Op().CONST:
            raise Exception("Use addConst() to add CONST ops")

        return self.currBasicBlock.addOperator(
            op, inputs, outputs, attributes, quant_info
        )

    def setExpectedReturnCode(self, val, desc=""):
        """Record the expected return code and, optionally, a description."""
        self.expectedReturnCode = val
        self.expectedFailureDesc = desc

        if val == TosaReturnCode.VALID:
            self.expectedFailure = False
        else:
            # Unpredictable or error results are considered expected failures
            # for conformance.
            self.expectedFailure = True

    def serialize(self):
        """Finish the graph (version + all basic blocks) and return the buffer.

        NOTE(review): this finishes self.builder, so calling it a second time
        on the same instance likely fails -- confirm before relying on reuse.
        """
        builder = self.builder

        Version.VersionStart(builder)
        Version.VersionAdd_major(builder, TOSA_VERSION[0])
        Version.VersionAdd_minor(builder, TOSA_VERSION[1])
        Version.VersionAdd_patch(builder, TOSA_VERSION[2])
        Version.VersionAdd_draft(builder, TOSA_VERSION[3])
        version = Version.VersionEnd(builder)

        fbv_bb = TosaSerializer.serializeObjVec(
            builder, self.basicBlocks, TosaGraph.TosaGraphStartBlocksVector
        )

        TosaGraph.TosaGraphStart(builder)
        TosaGraph.TosaGraphAddVersion(builder, version)
        TosaGraph.TosaGraphAddBlocks(builder, fbv_bb)
        graph = TosaGraph.TosaGraphEnd(builder)

        self.builder.Finish(graph, TOSA_GRAPH_IDENTIFIER)
        return self.builder.Output()

    def writeJson(self, tosa_filename):
        """Return a JSON test description for the "main" block.

        The description makes it easy for a third-party tool to pick up the
        test: input names/files, output names and made-up output file names,
        plus the expected return code and failure flags.
        """
        test_desc = dict()

        test_desc["tosa_file"] = tosa_filename
        ifm_name = []
        ifm_file = []
        ofm_name = []
        ofm_file = []

        for b in self.basicBlocks:
            if b.name == "main":
                for i in b.inputs:
                    ifm_name.append(i)
                    ifm_file.append(b.tensors[i].placeholderFilename)
                for o in b.outputs:
                    ofm_name.append(o)
                    # No OFM file exists until the reference tool runs, so any
                    # name is a good name.
                    ofm_file.append("ref-{}.npy".format(o))

        test_desc["ifm_name"] = ifm_name
        test_desc["ifm_file"] = ifm_file
        test_desc["ofm_name"] = ofm_name
        test_desc["ofm_file"] = ofm_file
        test_desc["expected_return_code"] = self.expectedReturnCode
        test_desc["expected_failure"] = self.expectedFailure
        if self.expectedFailureDesc:
            test_desc["expected_failure_desc"] = self.expectedFailureDesc

        return json.dumps(test_desc, indent=" ")

    def startBasicBlock(self, name):
        """Create a new basic block, make it current and append it to the graph."""
        self.currBasicBlock = TosaSerializerBasicBlock(name)
        self.basicBlocks.append(self.currBasicBlock)

    @staticmethod
    def _endVector(builder, length):
        """Close the open vector using either the Flatbuffers 2.x or 1.x API.

        2.x's EndVector() takes no arguments; 1.x's takes the element count. If
        Flatbuffers 1.x support is dropped, this can become a plain call.
        """
        try:
            return builder.EndVector()
        except TypeError:
            return builder.EndVector(length)

    @staticmethod
    def serializeStrVec(builder, vec, start_fcn):
        """Serialize a list of str as a vector of string offsets."""
        fb_strs = [builder.CreateString(i) for i in vec]
        start_fcn(builder, len(fb_strs))
        # Flatbuffers vectors are built back to front.
        for s in fb_strs[::-1]:
            builder.PrependUOffsetTRelative(s)
        return TosaSerializer._endVector(builder, len(fb_strs))

    @staticmethod
    def serializeUint8Vec(builder, vec):
        """Serialize a sequence of byte values (0..255) as a uint8 vector."""
        builder.StartVector(1, len(vec), 8)
        for v in vec[::-1]:
            builder.PrependUint8(v)
        return TosaSerializer._endVector(builder, len(vec))

    @staticmethod
    def serializeInt32Vec(builder, vec):
        """Serialize a sequence of ints as an int32 vector."""
        builder.StartVector(4, len(vec), 4)
        for v in vec[::-1]:
            builder.PrependInt32(v)
        return TosaSerializer._endVector(builder, len(vec))

    @staticmethod
    def serializeFpVec(builder, vec):
        """Serialize a sequence of floats as a float32 vector."""
        builder.StartVector(4, len(vec), 4)
        for v in vec[::-1]:
            builder.PrependFloat32(v)
        return TosaSerializer._endVector(builder, len(vec))

    @staticmethod
    def serializeObjVec(builder, vec, start_fcn):
        """Serialize objects exposing serialize(builder) as a vector of offsets."""
        # Children are serialized last-to-first, so the offsets can be
        # prepended in this order to reproduce the original element order.
        serialized_vec = [v.serialize(builder) for v in vec[::-1]]

        start_fcn(builder, len(vec))
        for v in serialized_vec:
            builder.PrependUOffsetTRelative(v)
        return TosaSerializer._endVector(builder, len(vec))

    @staticmethod
    def toList(val):
        """Return *val* unchanged if it is a list, otherwise wrap it as [val]."""
        if isinstance(val, list):
            return val
        return [val]
823