blob: f0d7c63e8f3087b441eab05448ab85283f69d179 [file] [log] [blame]
Kevin Chengfea5a372021-10-11 18:38:47 +00001# Copyright (c) 2020-2021, ARM Limited.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15#!/usr/bin/env python3
16
17import os
18import sys
19import json
20import flatbuffers
21import numpy as np
22import struct
23from enum import Enum, IntEnum, unique
24from tosa import (
25 TosaGraph,
26 TosaBasicBlock,
27 TosaTensor,
28 TosaOperator,
29 DType,
30 Op,
31 ResizeMode,
32 Version,
33)
34from tosa_ref_run import TosaReturnCode
35
36import tosa
37
# TOSA version stamped into every serialized graph (see TosaSerializer.serialize,
# which writes the four entries of TOSA_VERSION into the Version table).
TOSA_VERSION_MAJOR = 0
TOSA_VERSION_MINOR = 23
TOSA_VERSION_PATCH = 0
TOSA_VERSION_DRAFT = True

# Ordered as [major, minor, patch, draft].
TOSA_VERSION = [
    TOSA_VERSION_MAJOR,
    TOSA_VERSION_MINOR,
    TOSA_VERSION_PATCH,
    TOSA_VERSION_DRAFT,
]
# flatc's generated Python types offer no programmatic way to turn an integer
# DType value back into its name, so this string table is maintained by hand.
# The list index is the DType value (DTypeNames[dtype] is used for messages and
# dtype_str_to_val returns the index), so the order must match tosa.DType.
DType = tosa.DType.DType()
DTypeNames = [
    "UNKNOWN",
    "BOOL",
    "UINT8",
    "INT4",
    "INT8",
    "INT16",
    "INT32",
    "INT48",
    "FLOAT",
]

# Mask that keeps only the low 8 bits of a numpy unsigned integer.
ByteMask = np.uint64(255)
63
64
def dtype_str_to_val(name):
    """Return the integer DType value whose name matches *name*.

    Matching is case-insensitive (``str.casefold``) against ``DTypeNames``;
    the index of the matching entry is the DType value.

    Raises:
        ValueError: if *name* matches no known DType. ``ValueError`` is a
            subclass of ``Exception``, so callers catching ``Exception`` are
            unaffected.
    """
    # Casefold the query once instead of on every comparison.
    wanted = name.casefold()
    for value, dtype_name in enumerate(DTypeNames):
        if dtype_name.casefold() == wanted:
            return value
    raise ValueError("Unable to parse DType name {}".format(name))
71
72
class TosaSerializerUnion:
    """Collect the fields of one flatbuffers union table and serialize them.

    Subclasses fill in ``optFcns`` (the table's start/end functions),
    ``utype`` (the union type tag) and the per-kind field lists, each holding
    ``(add_function, value)`` pairs.
    """

    def __init__(self):
        # (start_fcn, end_fcn) of the table; set by a subclass option method.
        self.optFcns = None

        # Union type tag; set by a subclass option method.
        self.utype = None

        # Pending fields, grouped by kind, as (add_function, value) pairs.
        self.ints = []
        self.bools = []
        self.floats = []
        self.strings = []
        self.intvecs = []
        self.fpvecs = []

    def serialize(self, builder):
        """Write the table to *builder* and return its offset."""
        # Strings and vectors are separate flatbuffer objects, so they have to
        # be created before the table itself is started.
        prebuilt = [(add, builder.CreateString(text)) for add, text in self.strings]
        prebuilt += [
            (add, TosaSerializer.serializeInt32Vec(builder, vec))
            for add, vec in self.intvecs
        ]
        prebuilt += [
            (add, TosaSerializer.serializeFpVec(builder, vec))
            for add, vec in self.fpvecs
        ]

        table_start, table_end = self.optFcns

        # Scalars first (ints, bools, floats), then the prebuilt offsets in
        # string / int-vector / float-vector order.
        table_start(builder)
        for add, payload in self.ints + self.bools + self.floats + prebuilt:
            add(builder, payload)
        return table_end(builder)
133
134
class TosaSerializerAttribute(TosaSerializerUnion):
    """Builder for the TOSA ``Attribute`` union.

    Each ``*Attribute`` method records the union type (``utype``), the table's
    start/end functions (``optFcns``) and the field values for ``serialize()``.
    Only one of them should be called per instance; calling several would mix
    their fields into one table.
    """

    def __init__(self):
        super().__init__()

    def PoolAttribute(self, kernel, stride, padding):
        """Select PoolAttribute with kernel, stride and padding int vectors."""
        from tosa import PoolAttribute as a, Attribute

        self.utype = Attribute.Attribute().PoolAttribute

        self.optFcns = (a.PoolAttributeStart, a.PoolAttributeEnd)
        self.intvecs.append((a.PoolAttributeAddPadding, padding))
        self.intvecs.append((a.PoolAttributeAddKernel, kernel))
        self.intvecs.append((a.PoolAttributeAddStride, stride))

    def ConvAttribute(self, padding, stride, dilation):
        """Select ConvAttribute with padding, stride and dilation int vectors."""
        from tosa import ConvAttribute as a, Attribute

        self.utype = Attribute.Attribute().ConvAttribute
        self.optFcns = (a.ConvAttributeStart, a.ConvAttributeEnd)

        self.intvecs.append((a.ConvAttributeAddPadding, padding))
        self.intvecs.append((a.ConvAttributeAddStride, stride))
        self.intvecs.append((a.ConvAttributeAddDilation, dilation))

    def TransposeConvAttribute(self, outpad, stride, dilation, output_shape):
        """Select TransposeConvAttribute; all four arguments are int vectors."""
        from tosa import TransposeConvAttribute as a, Attribute

        self.utype = Attribute.Attribute().TransposeConvAttribute
        self.optFcns = (a.TransposeConvAttributeStart, a.TransposeConvAttributeEnd)

        self.intvecs.append((a.TransposeConvAttributeAddOutpad, outpad))
        self.intvecs.append((a.TransposeConvAttributeAddStride, stride))
        self.intvecs.append((a.TransposeConvAttributeAddDilation, dilation))
        self.intvecs.append((a.TransposeConvAttributeAddOutputShape, output_shape))

    def ReluNAttribute(self, maxint, maxfp):
        """Select ReluNAttribute with an integer and a floating-point maximum."""
        from tosa import ReluNAttribute as a, Attribute

        self.utype = Attribute.Attribute().ReluNAttribute
        self.optFcns = (a.ReluNAttributeStart, a.ReluNAttributeEnd)

        self.ints.append((a.ReluNAttributeAddMaxInt, maxint))
        # The *_fp field is floating point, so queue it with the floats.
        self.floats.append((a.ReluNAttributeAddMaxFp, maxfp))

    def AxisAttribute(self, axis):
        """Select AxisAttribute with a single integer axis."""
        from tosa import AxisAttribute as a, Attribute

        self.utype = Attribute.Attribute().AxisAttribute
        self.optFcns = (a.AxisAttributeStart, a.AxisAttributeEnd)

        self.ints.append((a.AxisAttributeAddAxis, axis))

    def ReshapeAttribute(self, shape):
        """Select ReshapeAttribute with the target shape as an int vector."""
        from tosa import ReshapeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ReshapeAttribute
        self.optFcns = (a.ReshapeAttributeStart, a.ReshapeAttributeEnd)

        self.intvecs.append((a.ReshapeAttributeAddShape, shape))

    def SliceAttribute(self, begin, size):
        """Select SliceAttribute with begin and size int vectors."""
        from tosa import SliceAttribute as a, Attribute

        self.utype = Attribute.Attribute().SliceAttribute
        self.optFcns = (a.SliceAttributeStart, a.SliceAttributeEnd)

        self.intvecs.append((a.SliceAttributeAddBegin, begin))
        self.intvecs.append((a.SliceAttributeAddSize, size))

    def TileAttribute(self, multiples):
        """Select TileAttribute with the multiples int vector."""
        from tosa import TileAttribute as a, Attribute

        self.utype = Attribute.Attribute().TileAttribute
        self.optFcns = (a.TileAttributeStart, a.TileAttributeEnd)

        self.intvecs.append((a.TileAttributeAddMultiples, multiples))

    def ResizeAttribute(
        self, output_size, stride, offset, shift, stride_fp, offset_fp, mode
    ):
        """Select ResizeAttribute.

        output_size, stride and offset are int vectors; shift and mode are
        scalars; stride_fp and offset_fp are float vectors.
        """
        from tosa import ResizeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ResizeAttribute
        self.optFcns = (a.ResizeAttributeStart, a.ResizeAttributeEnd)

        self.intvecs.append((a.ResizeAttributeAddOutputSize, output_size))
        self.intvecs.append((a.ResizeAttributeAddStride, stride))
        self.intvecs.append((a.ResizeAttributeAddOffset, offset))
        self.ints.append((a.ResizeAttributeAddShift, shift))
        self.fpvecs.append((a.ResizeAttributeAddStrideFp, stride_fp))
        self.fpvecs.append((a.ResizeAttributeAddOffsetFp, offset_fp))
        self.ints.append((a.ResizeAttributeAddMode, mode))

    def ClampAttribute(self, minint, maxint, minfp, maxfp):
        """Select ClampAttribute with integer and floating-point bounds."""
        from tosa import ClampAttribute as a, Attribute

        self.utype = Attribute.Attribute().ClampAttribute
        self.optFcns = (a.ClampAttributeStart, a.ClampAttributeEnd)

        self.ints.append((a.ClampAttributeAddMinInt, minint))
        self.ints.append((a.ClampAttributeAddMaxInt, maxint))

        # The *_fp bounds are floating point, so queue them with the floats.
        self.floats.append((a.ClampAttributeAddMinFp, minfp))
        self.floats.append((a.ClampAttributeAddMaxFp, maxfp))

    def RescaleAttribute(
        self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel
    ):
        """Select RescaleAttribute.

        input_zp/output_zp are ints, multiplier/shift are int vectors and
        scale32/double_round/per_channel are booleans.
        """
        from tosa import RescaleAttribute as a, Attribute

        self.utype = Attribute.Attribute().RescaleAttribute
        self.optFcns = (a.RescaleAttributeStart, a.RescaleAttributeEnd)

        self.ints.append((a.RescaleAttributeAddInputZp, input_zp))
        self.ints.append((a.RescaleAttributeAddOutputZp, output_zp))
        self.intvecs.append((a.RescaleAttributeAddMultiplier, multiplier))
        self.intvecs.append((a.RescaleAttributeAddShift, shift))
        self.bools.append((a.RescaleAttributeAddScale32, scale32))
        self.bools.append((a.RescaleAttributeAddDoubleRound, double_round))
        self.bools.append((a.RescaleAttributeAddPerChannel, per_channel))

    def MulAttribute(self, shift):
        """Select MulAttribute with an integer shift."""
        from tosa import MulAttribute as a, Attribute

        self.utype = Attribute.Attribute().MulAttribute
        self.optFcns = (a.MulAttributeStart, a.MulAttributeEnd)

        self.ints.append((a.MulAttributeAddShift, shift))

    def ArithmeticRightShiftAttribute(self, round):
        """Select ArithmeticRightShiftAttribute with a boolean round flag."""
        from tosa import ArithmeticRightShiftAttribute as a, Attribute

        self.utype = Attribute.Attribute().ArithmeticRightShiftAttribute
        self.optFcns = (
            a.ArithmeticRightShiftAttributeStart,
            a.ArithmeticRightShiftAttributeEnd,
        )

        self.bools.append((a.ArithmeticRightShiftAttributeAddRound, round))

    def CustomAttribute(self, identifier):
        """Select CustomAttribute with a string identifier."""
        from tosa import CustomAttribute as a, Attribute

        self.utype = Attribute.Attribute().CustomAttribute
        self.optFcns = (a.CustomAttributeStart, a.CustomAttributeEnd)

        self.strings.append((a.CustomAttributeAddIdentifier, identifier))

    def CondIfAttribute(self, then_branch, else_branch):
        """Select CondIfAttribute with the then/else basic-block names."""
        from tosa import CondIfAttribute as a, Attribute

        self.utype = Attribute.Attribute().CondIfAttribute
        self.optFcns = (a.CondIfAttributeStart, a.CondIfAttributeEnd)

        self.strings.append((a.CondIfAttributeAddThenBranch, then_branch))
        self.strings.append((a.CondIfAttributeAddElseBranch, else_branch))

    def WhileLoopAttribute(self, cond_branch, body_branch):
        """Select WhileLoopAttribute with the cond/body basic-block names."""
        from tosa import WhileLoopAttribute as a, Attribute

        self.utype = Attribute.Attribute().WhileLoopAttribute
        self.optFcns = (a.WhileLoopAttributeStart, a.WhileLoopAttributeEnd)

        self.strings.append((a.WhileLoopAttributeAddCondBranch, cond_branch))
        self.strings.append((a.WhileLoopAttributeAddBodyBranch, body_branch))
302
303
class TosaSerializerQuantInfo(TosaSerializerUnion):
    """Builder for the TOSA ``QuantInfo`` union.

    Each method below picks one union member and queues its integer
    ``*_zp`` fields for ``serialize()``.
    """

    def __init__(self):
        super().__init__()

    def ConvQuantInfo(self, input_zp, weight_zp):
        """Select ConvQuantInfo with input and weight zero points."""
        from tosa import ConvQuantInfo as qi, QuantInfo

        self.utype = QuantInfo.QuantInfo().ConvQuantInfo
        self.optFcns = (qi.ConvQuantInfoStart, qi.ConvQuantInfoEnd)
        self.ints.extend(
            [
                (qi.ConvQuantInfoAddInputZp, input_zp),
                (qi.ConvQuantInfoAddWeightZp, weight_zp),
            ]
        )

    def UnaryQuantInfo(self, input_zp, output_zp):
        """Select UnaryQuantInfo with input and output zero points."""
        from tosa import UnaryQuantInfo as qi, QuantInfo

        self.utype = QuantInfo.QuantInfo().UnaryQuantInfo
        self.optFcns = (qi.UnaryQuantInfoStart, qi.UnaryQuantInfoEnd)
        self.ints.extend(
            [
                (qi.UnaryQuantInfoAddInputZp, input_zp),
                (qi.UnaryQuantInfoAddOutputZp, output_zp),
            ]
        )

    def MatMulQuantInfo(self, a_zp, b_zp):
        """Select MatMulQuantInfo with the A and B operand zero points."""
        from tosa import MatMulQuantInfo as qi, QuantInfo

        self.utype = QuantInfo.QuantInfo().MatMulQuantInfo
        self.optFcns = (qi.MatMulQuantInfoStart, qi.MatMulQuantInfoEnd)
        self.ints.extend(
            [
                (qi.MatMulQuantInfoAddAZp, a_zp),
                (qi.MatMulQuantInfoAddBZp, b_zp),
            ]
        )

    def PadQuantInfo(self, input_zp):
        """Select PadQuantInfo with the input zero point."""
        from tosa import PadQuantInfo as qi, QuantInfo

        self.utype = QuantInfo.QuantInfo().PadQuantInfo
        self.optFcns = (qi.PadQuantInfoStart, qi.PadQuantInfoEnd)
        self.ints.extend([(qi.PadQuantInfoAddInputZp, input_zp)])
340
341
class TosaSerializerTensor:
    """A named, typed tensor that may carry constant data.

    ``dtype`` is an integer value of the tosa DType enumeration. ``data`` may
    be a numpy array or a list; it is flattened and stored as Python floats
    for FLOAT tensors and as Python ints for every other dtype. Any other
    ``data`` value (including None) means the tensor carries no data.
    """

    def __init__(
        self,
        name,
        shape,
        dtype,
        data=None,
        placeholderFilename=None,
    ):
        self.name = name

        if isinstance(shape, np.ndarray):
            shape = shape.astype(int).tolist()
        self.shape = list(map(int, shape))
        self.dtype = dtype

        # FLOAT data must keep its fractional part. Previously every dtype was
        # pushed through int(), silently truncating float constants
        # (e.g. 0.5 -> 0) before serialize() packed them as float32.
        elem_type = float if dtype == DType.FLOAT else int
        if isinstance(data, np.ndarray):
            self.data = list(
                map(elem_type, data.flatten().astype(elem_type).tolist())
            )
        elif isinstance(data, list):
            self.data = list(map(elem_type, data))
        else:
            self.data = None

        # Filename for placeholder tensors. These get generated by the test generation
        # process and are written to disk, but are considered input tensors by the network
        # so they do not appear in the TOSA serialization. However, if we want to form a unit
        # test around these input tensors, we can get the filename from here.
        self.placeholderFilename = placeholderFilename

    def __str__(self):
        return "TosaSerializerTensor name: {} shape: {} dtype: {}".format(
            self.name,
            self.shape,
            DTypeNames[self.dtype],
        )

    def setDtype(self, dtype):
        """Change the tensor's dtype. Existing data is not re-converted."""
        self.dtype = dtype

    def _dataAsBytes(self):
        """Return ``self.data`` encoded as a little-endian list of byte values.

        Raises:
            Exception: if the dtype has no supported data encoding.
        """
        if self.dtype == DType.INT4:
            # Two 4-bit values per byte, low nibble first; an odd trailing
            # element is paired with 0.
            packed = []
            for i in range(0, len(self.data), 2):
                low = self.data[i]
                high = self.data[i + 1] if i + 1 < len(self.data) else 0
                packed.append((low & 0xF) | ((high & 0xF) << 4))
            return packed

        if self.dtype == DType.FLOAT:
            # "<f" is little-endian IEEE-754 float32.
            return [b for val in self.data for b in struct.pack("<f", val)]

        widths = {
            DType.BOOL: 1,
            DType.INT8: 1,
            DType.INT16: 2,
            DType.INT32: 4,
            DType.INT48: 6,
        }
        width = widths.get(self.dtype)
        if width is None:
            raise Exception(
                "unsupported data type {}".format(DTypeNames[self.dtype])
            )
        # Python int shift/mask yields two's-complement bytes for negative
        # values. This matches the old np.uintN(val) wrap-around, which raises
        # OverflowError for negative values under NumPy >= 2.0.
        return [(val >> (8 * i)) & 0xFF for val in self.data for i in range(width)]

    def serialize(self, builder):
        """Write this tensor as a TosaTensor table and return its offset."""
        fb_name = builder.CreateString(self.name)
        fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape)
        if self.data:
            fb_data = TosaSerializer.serializeUint8Vec(builder, self._dataAsBytes())

        TosaTensor.TosaTensorStart(builder)
        TosaTensor.TosaTensorAddName(builder, fb_name)
        TosaTensor.TosaTensorAddShape(builder, fb_shapes)
        TosaTensor.TosaTensorAddType(builder, self.dtype)
        if self.data:
            TosaTensor.TosaTensorAddData(builder, fb_data)

        return TosaTensor.TosaTensorEnd(builder)
455
456
class TosaSerializerOperator:
    """One graph operator: opcode, input/output tensor names, and the optional
    attribute and quantization-info unions."""

    def __init__(self, op, inputs, outputs, attributes=None, quantInfo=None):
        self.op = op
        self.attributes = attributes
        self.inputs = TosaSerializer.toList(inputs)
        self.outputs = TosaSerializer.toList(outputs)
        self.quantInfo = quantInfo

    def __str__(self):
        pieces = ["Op {}\n----\n".format(self.op)]
        pieces += [" Input: {}\n".format(tname) for tname in self.inputs]
        pieces += [" Output: {}\n".format(tname) for tname in self.outputs]
        return "".join(pieces)

    def serialize(self, builder):
        """Write this operator as a TosaOperator table and return its offset."""
        fb_inputs = TosaSerializer.serializeStrVec(
            builder, self.inputs, TosaOperator.TosaOperatorStartInputsVector
        )
        fb_outputs = TosaSerializer.serializeStrVec(
            builder, self.outputs, TosaOperator.TosaOperatorStartOutputsVector
        )

        has_attributes = self.attributes is not None
        has_qinfo = self.quantInfo is not None

        # The union payloads are separate tables, so they are finished before
        # the operator table is started.
        fb_attributes = self.attributes.serialize(builder) if has_attributes else None
        fb_qinfo = self.quantInfo.serialize(builder) if has_qinfo else None

        TosaOperator.TosaOperatorStart(builder)
        TosaOperator.TosaOperatorAddOp(builder, self.op)
        TosaOperator.TosaOperatorAddInputs(builder, fb_inputs)
        TosaOperator.TosaOperatorAddOutputs(builder, fb_outputs)
        if has_attributes:
            TosaOperator.TosaOperatorAddAttributeType(builder, self.attributes.utype)
            TosaOperator.TosaOperatorAddAttribute(builder, fb_attributes)
        if has_qinfo:
            TosaOperator.TosaOperatorAddQuantInfoType(builder, self.quantInfo.utype)
            TosaOperator.TosaOperatorAddQuantInfo(builder, fb_qinfo)

        return TosaOperator.TosaOperatorEnd(builder)
501
502
class TosaSerializerBasicBlock:
    """A named basic block: its tensors, operators, and input/output names."""

    def __init__(self, name):
        self.name = name
        self.operators = []

        # Dict assures uniqueness, but allows us to look up by name
        self.tensors = dict()

        self.inputs = []
        self.outputs = []

    def __str__(self):
        # TosaSerializer.__str__ concatenates str() of every block; without
        # this method that printed only the default object repr.
        text = "BasicBlock {}\n".format(self.name)
        text += "  Inputs: {}\n".format(", ".join(map(str, self.inputs)))
        text += "  Outputs: {}\n".format(", ".join(map(str, self.outputs)))
        for tens in self.tensors.values():
            text += "  {}\n".format(tens)
        for op in self.operators:
            text += str(op)
        return text

    def addTensor(
        self,
        name,
        shape,
        dtype,
        data=None,
        placeholderFilename=None,
    ):
        """Return the tensor called *name*, creating it if it does not exist.

        If a tensor with that name was already added, it is returned unchanged
        and the other arguments are ignored.
        """
        if name not in self.tensors:
            self.tensors[name] = TosaSerializerTensor(
                name, shape, dtype, data, placeholderFilename
            )
        return self.tensors[name]

    def addInput(self, name):
        self.inputs.append(name)

    def addOutput(self, name):
        self.outputs.append(name)

    def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None):
        self.operators.append(
            TosaSerializerOperator(op, inputs, outputs, attributes, quant_info)
        )

    def serialize(self, builder):
        """Write this block as a TosaBasicBlock table and return its offset."""
        fb_name = builder.CreateString(self.name)
        fbv_inputs = TosaSerializer.serializeStrVec(
            builder, list(self.inputs), TosaBasicBlock.TosaBasicBlockStartInputsVector
        )
        fbv_outputs = TosaSerializer.serializeStrVec(
            builder, list(self.outputs), TosaBasicBlock.TosaBasicBlockStartOutputsVector
        )
        fbv_tensors = TosaSerializer.serializeObjVec(
            builder,
            list(self.tensors.values()),
            TosaBasicBlock.TosaBasicBlockStartTensorsVector,
        )
        fbv_operators = TosaSerializer.serializeObjVec(
            builder, self.operators, TosaBasicBlock.TosaBasicBlockStartOperatorsVector
        )

        TosaBasicBlock.TosaBasicBlockStart(builder)
        TosaBasicBlock.TosaBasicBlockAddName(builder, fb_name)
        TosaBasicBlock.TosaBasicBlockAddInputs(builder, fbv_inputs)
        TosaBasicBlock.TosaBasicBlockAddOutputs(builder, fbv_outputs)
        TosaBasicBlock.TosaBasicBlockAddTensors(builder, fbv_tensors)
        TosaBasicBlock.TosaBasicBlockAddOperators(builder, fbv_operators)
        return TosaBasicBlock.TosaBasicBlockEnd(builder)
567
568
@unique
class TensorDir(IntEnum):
    """Kinds of tensor in a generated graph, as distinct integer values.

    ``@unique`` guarantees no two members share a value.
    """

    PLACEHOLDER = 0  # graph input
    CONST = 1  # constant tensor
    INTERMEDIATE = 2  # produced and consumed inside the graph
    RESULT = 3  # graph output
575
576
class TosaSerializer:
    """Assemble a TOSA graph and serialize it to a flatbuffer.

    Tensors and operators are added to the current basic block (a "main"
    block is started on construction). ``serialize()`` returns the flatbuffer
    bytes; ``writeJson()`` returns a JSON test-descriptor string.

    Args:
        pathPrefix: directory in which ``addPlaceholder()`` saves ``.npy``
            files for placeholder inputs.
    """

    def __init__(self, pathPrefix):
        # Replaced by a fresh builder on every serialize() call.
        self.builder = flatbuffers.Builder(0)

        self.basicBlocks = []
        self.startBasicBlock("main")
        self.pathPrefix = pathPrefix

        # Indices used for adding/naming tensors
        self.currInputIdx = 0
        # NOTE(review): currConstIdx is never read; addConst() names constants
        # from currInputIdx instead.
        self.currConstIdx = 0
        self.currLayerIdx = 1
        self.currResultIdx = 0

        # Is this an illegal test that is expected to fail?
        self.expectedReturnCode = TosaReturnCode.VALID
        self.expectedFailure = False
        self.expectedFailureDesc = ""

    def __str__(self):
        return "".join(str(bb) for bb in self.basicBlocks)

    def addPlaceholder(self, shape, dtype, vals):
        """Add a new block input tensor named "input-N" and return it.

        The tensor carries no data in the graph. If *vals* is not None it is
        saved with ``np.save`` to "<pathPrefix>/input-N.npy".
        """
        if not self.currBasicBlock:
            raise Exception("addTensor called without valid basic block")

        name = "input-{}".format(self.currInputIdx)
        filename = "{}.npy".format(name)
        self.currInputIdx += 1

        tens = self.currBasicBlock.addTensor(name, shape, dtype, None, filename)
        # This is always an input to the block
        self.currBasicBlock.addInput(name)

        if vals is not None:
            np.save(os.path.join(self.pathPrefix, filename), vals, allow_pickle=False)

        return tens

    def addConst(self, shape, dtype, vals):
        """Add a constant tensor "const-N" holding *vals*, plus its CONST op."""
        if not self.currBasicBlock:
            raise Exception("addTensor called without valid basic block")

        # Constants share the input counter; left unchanged so generated
        # tensor names do not shift.
        name = "const-{}".format(self.currInputIdx)
        self.currInputIdx += 1

        tens = self.currBasicBlock.addTensor(name, shape, dtype, vals)
        # Add the operator now: a CONST op with no inputs producing the tensor.
        self.currBasicBlock.addOperator(tosa.Op.Op().CONST, [], name)

        return tens

    def addIntermediate(self, shape, dtype):
        """Add a data-less tensor named "layer-N" (N starts at 1) and return it."""
        if not self.currBasicBlock:
            raise Exception("addTensor called without valid basic block")

        name = "layer-{}".format(self.currLayerIdx)
        self.currLayerIdx += 1

        return self.currBasicBlock.addTensor(name, shape, dtype, None)

    def addInputTensor(self, tensor):
        """Register an existing tensor object as an input of the current block."""
        self.currBasicBlock.addTensor(tensor.name, tensor.shape, tensor.dtype)
        self.currBasicBlock.addInput(tensor.name)

    def addOutputTensor(self, tensor):
        """Mark an already-added tensor as an output of the current block."""
        self.currBasicBlock.addOutput(tensor.name)

    def addOutput(self, shape, dtype):
        """Add a new block output tensor named "result-N" and return it."""
        if not self.currBasicBlock:
            raise Exception("addTensor called without valid basic block")

        name = "result-{}".format(self.currResultIdx)
        self.currResultIdx += 1

        tens = self.currBasicBlock.addTensor(name, shape, dtype, None)
        self.currBasicBlock.addOutput(name)
        return tens

    def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None):
        """Append a non-CONST operator to the current basic block.

        Raises:
            Exception: if *op* is CONST; constants must go through addConst().
        """
        if op == tosa.Op.Op().CONST:
            raise Exception("Use addConst() to add CONST ops")

        return self.currBasicBlock.addOperator(
            op, inputs, outputs, attributes, quant_info
        )

    def setExpectedReturnCode(self, val, desc=""):
        """Record the expected return code; anything but VALID is an expected failure."""
        self.expectedReturnCode = val
        self.expectedFailureDesc = desc
        # Unpredictable or error results are considered expected failures
        # for conformance
        self.expectedFailure = val != TosaReturnCode.VALID

    def serialize(self):
        """Serialize all basic blocks into a TosaGraph and return the bytes."""
        # Use a fresh builder each call. Reusing the previously finished
        # builder would leave the earlier graph's bytes in Output().
        builder = flatbuffers.Builder(0)
        self.builder = builder

        Version.VersionStart(builder)
        Version.VersionAdd_major(builder, TOSA_VERSION[0])
        Version.VersionAdd_minor(builder, TOSA_VERSION[1])
        Version.VersionAdd_patch(builder, TOSA_VERSION[2])
        Version.VersionAdd_draft(builder, TOSA_VERSION[3])
        version = Version.VersionEnd(builder)

        fbv_bb = TosaSerializer.serializeObjVec(
            builder, self.basicBlocks, TosaGraph.TosaGraphStartBlocksVector
        )

        TosaGraph.TosaGraphStart(builder)
        TosaGraph.TosaGraphAddVersion(builder, version)
        TosaGraph.TosaGraphAddBlocks(builder, fbv_bb)
        graph = TosaGraph.TosaGraphEnd(builder)

        builder.Finish(graph)
        return builder.Output()

    def writeJson(self, tosa_filename):
        """Return (does not write) a JSON test descriptor for the "main" block.

        The descriptor lists the input/output tensor names and files so a
        third-party tool can easily pick up the test and generate commands.
        """
        ifm_name = []
        ifm_file = []
        ofm_name = []
        ofm_file = []

        for b in self.basicBlocks:
            if b.name != "main":
                continue
            for i in b.inputs:
                ifm_name.append(i)
                ifm_file.append(b.tensors[i].placeholderFilename)
            for o in b.outputs:
                ofm_name.append(o)
                # Make up an OFM filename here. One isn't generated until the reference tool is
                # run, so any name is a good name
                ofm_file.append("ref-{}.npy".format(o))

        # Key insertion order determines the JSON field order.
        test_desc = {
            "tosa_file": tosa_filename,
            "ifm_name": ifm_name,
            "ifm_file": ifm_file,
            "ofm_name": ofm_name,
            "ofm_file": ofm_file,
            "expected_return_code": self.expectedReturnCode,
            "expected_failure": self.expectedFailure,
        }
        if self.expectedFailureDesc:
            test_desc["expected_failure_desc"] = self.expectedFailureDesc

        return json.dumps(test_desc, indent=" ")

    def startBasicBlock(self, name):
        """Create a new basic block, append it, and make it the current block."""
        self.currBasicBlock = TosaSerializerBasicBlock(name)
        self.basicBlocks.append(self.currBasicBlock)

    @staticmethod
    def _endVector(builder, num_elems):
        """Finish the vector under construction on *builder* and return its offset.

        Flatbuffers 2.x ``EndVector()`` takes no argument while 1.x requires the
        element count. The 2.x form is tried first; if/when Flatbuffers 1.x
        support is deprecated, the TypeError fallback can be removed.
        """
        try:
            return builder.EndVector()
        except TypeError:
            return builder.EndVector(num_elems)

    @staticmethod
    def serializeStrVec(builder, vec, start_fcn):
        """Serialize a list of strings as a vector of string offsets."""
        fb_strs = [builder.CreateString(i) for i in vec]
        start_fcn(builder, len(fb_strs))
        # Flatbuffers vectors are built back to front.
        for s in reversed(fb_strs):
            builder.PrependUOffsetTRelative(s)
        return TosaSerializer._endVector(builder, len(fb_strs))

    @staticmethod
    def serializeUint8Vec(builder, vec):
        """Serialize a list of byte values as a uint8 vector (8-byte aligned)."""
        builder.StartVector(1, len(vec), 8)
        for v in reversed(vec):
            builder.PrependUint8(v)
        return TosaSerializer._endVector(builder, len(vec))

    @staticmethod
    def serializeInt32Vec(builder, vec):
        """Serialize a list of ints as an int32 vector."""
        builder.StartVector(4, len(vec), 4)
        for v in reversed(vec):
            builder.PrependInt32(v)
        return TosaSerializer._endVector(builder, len(vec))

    @staticmethod
    def serializeFpVec(builder, vec):
        """Serialize a list of floats as a float32 vector."""
        builder.StartVector(4, len(vec), 4)
        for v in reversed(vec):
            builder.PrependFloat32(v)
        return TosaSerializer._endVector(builder, len(vec))

    @staticmethod
    def serializeObjVec(builder, vec, start_fcn):
        """Serialize each object (via its ``serialize``) into a vector of offsets."""
        # Child tables must be finished before the vector is started; they are
        # serialized last-to-first so prepending restores the original order.
        serialized_vec = [v.serialize(builder) for v in vec[::-1]]

        start_fcn(builder, len(vec))
        for v in serialized_vec:
            builder.PrependUOffsetTRelative(v)
        return TosaSerializer._endVector(builder, len(vec))

    @staticmethod
    def toList(val):
        """Return *val* if it is a list, a list copy of a tuple, else ``[val]``."""
        if isinstance(val, list):
            return val
        if isinstance(val, tuple):
            return list(val)
        return [val]
811