blob: 07e0e1a7e7aeace2746c1d6a460ea83c2896aa09 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
2
3# Copyright (c) 2020, ARM Limited.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17#!/usr/bin/env python3
18
19import flatbuffers
20import numpy as np
21from enum import Enum, IntEnum, unique
22from tosa import TosaGraph, TosaBasicBlock, TosaTensor, TosaOperator, DType, Format, Usage, Op, ResizeMode, Version
23import tosa
24import os
25import json
26
# With the way flatc generates its python types, there is no programmatic way
# to get string names for the integer types. Manually maintain a string table
# here. The list index of each name is used as its integer DType value (see
# dtype_str_to_val), so the order presumably must mirror the DType enum in the
# schema -- keep the two in sync.
DTypeNames = ['UNKNOWN',
              'BOOL',
              'AINT8',
              'UINT8',
              'INT4',
              'INT8',
              'INT16',
              'INT32',
              'INT48',
              'FLOAT']


def dtype_str_to_val(name):
    '''Return the integer DType value for a DType name.

    The comparison is case-insensitive (str.casefold), so 'int8', 'INT8' and
    'Int8' all map to the same value.

    Raises ValueError (a subclass of Exception, so existing
    ``except Exception`` handlers still catch it) if name is not a known
    DType name.
    '''
    wanted = name.casefold()
    for val, dtype_name in enumerate(DTypeNames):
        if dtype_name.casefold() == wanted:
            return val
    raise ValueError('Unable to parse DType name {}'.format(name))
47
48
class TosaSerializerUnion:
    '''This class handles encapsulating and serializing union types into flatbuffers'''

    def __init__(self):
        # (start_fcn, end_fcn) for the union member's table; filled in by the
        # subclass constructor methods.
        self.optFcns = None
        # Union discriminator value; also filled in by the subclass methods.
        self.utype = None
        # Field lists: each entry is an (add_fcn, value) pair.
        self.ints = []
        self.bools = []
        self.floats = []
        self.strings = []
        self.intvecs = []

    def serialize(self, builder):
        '''Write the recorded fields as one flatbuffer table and return its offset.

        String and int32-vector payloads are created before the table is
        started, because their offsets are what the add functions record.
        Fields are then added in the order ints, bools, floats, strings,
        int vectors.
        '''
        str_offsets = [(add_fcn, builder.CreateString(text))
                       for add_fcn, text in self.strings]
        vec_offsets = [(add_fcn, TosaSerializer.serializeInt32Vec(builder, vec))
                       for add_fcn, vec in self.intvecs]

        start_fcn, end_fcn = self.optFcns

        start_fcn(builder)
        for add_fcn, value in self.ints + self.bools + self.floats + str_offsets + vec_offsets:
            add_fcn(builder, value)
        return end_fcn(builder)
100
class TosaSerializerAttribute(TosaSerializerUnion):
    '''This class handles encapsulating all of the enumerated types for attributes.

    Each *Attribute() method selects one member of the tosa Attribute union
    (sets utype and optFcns) and records that member's fields. The inherited
    serialize() then writes the table. Call one such method per instance: a
    second call would overwrite utype/optFcns but keep the fields already
    recorded by the first.

    Scalar fields go into the list matching their value kind (ints, bools,
    floats, strings); list-valued fields go into intvecs and are written as
    int32 vectors.
    '''

    def __init__(self):
        super().__init__()

    def Pool2dAttribute(self, kernel, stride, padding):
        '''Select Pool2dAttribute; all three fields are int lists.'''
        from tosa import Pool2dAttribute as a, Attribute

        self.utype = Attribute.Attribute().Pool2dAttribute

        self.optFcns = (a.Pool2dAttributeStart, a.Pool2dAttributeEnd)
        self.intvecs.append((a.Pool2dAttributeAddPadding, padding))
        self.intvecs.append((a.Pool2dAttributeAddKernel, kernel))
        self.intvecs.append((a.Pool2dAttributeAddStride, stride))

    def Conv2dAttribute(self, padding, stride, dilation):
        '''Select Conv2dAttribute; all three fields are int lists.'''
        from tosa import Conv2dAttribute as a, Attribute

        self.utype = Attribute.Attribute().Conv2dAttribute
        self.optFcns = (a.Conv2dAttributeStart, a.Conv2dAttributeEnd)

        self.intvecs.append((a.Conv2dAttributeAddPadding, padding))
        self.intvecs.append((a.Conv2dAttributeAddStride, stride))
        self.intvecs.append((a.Conv2dAttributeAddDilation, dilation))

    def TransposeConv2DAttribute(self, outpad, stride, dilation, output_shape):
        '''Select TransposeConv2dAttribute; all four fields are int lists.'''
        from tosa import TransposeConv2dAttribute as a, Attribute

        self.utype = Attribute.Attribute().TransposeConv2dAttribute
        self.optFcns = (a.TransposeConv2dAttributeStart, a.TransposeConv2dAttributeEnd)

        self.intvecs.append((a.TransposeConv2dAttributeAddOutpad, outpad))
        self.intvecs.append((a.TransposeConv2dAttributeAddStride, stride))
        self.intvecs.append((a.TransposeConv2dAttributeAddDilation, dilation))
        self.intvecs.append((a.TransposeConv2dAttributeAddOutputShape, output_shape))

    def ReluNAttribute(self, maxint, maxfp):
        '''Select ReluNAttribute with an integer and a floating-point maximum.'''
        from tosa import ReluNAttribute as a, Attribute

        self.utype = Attribute.Attribute().ReluNAttribute
        self.optFcns = (a.ReluNAttributeStart, a.ReluNAttributeEnd)

        self.ints.append((a.ReluNAttributeAddMaxInt, maxint))
        # max_fp is the floating-point limit, so it belongs with the floats
        # (previously filed under ints; the add call itself is unchanged).
        self.floats.append((a.ReluNAttributeAddMaxFp, maxfp))

    def AxisAttribute(self, axis):
        '''Select AxisAttribute with a single integer axis.'''
        from tosa import AxisAttribute as a, Attribute

        self.utype = Attribute.Attribute().AxisAttribute
        self.optFcns = (a.AxisAttributeStart, a.AxisAttributeEnd)

        self.ints.append((a.AxisAttributeAddAxis, axis))

    def ReshapeAttribute(self, shape):
        '''Select ReshapeAttribute with the new shape as an int list.'''
        from tosa import ReshapeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ReshapeAttribute
        self.optFcns = (a.ReshapeAttributeStart, a.ReshapeAttributeEnd)

        self.intvecs.append((a.ReshapeAttributeAddShape, shape))

    def SliceAttribute(self, begin, size):
        '''Select SliceAttribute; begin and size are int lists.'''
        from tosa import SliceAttribute as a, Attribute

        self.utype = Attribute.Attribute().SliceAttribute
        self.optFcns = (a.SliceAttributeStart, a.SliceAttributeEnd)

        self.intvecs.append((a.SliceAttributeAddBegin, begin))
        self.intvecs.append((a.SliceAttributeAddSize, size))

    def TileAttribute(self, multiples):
        '''Select TileAttribute with the per-axis multiples as an int list.'''
        from tosa import TileAttribute as a, Attribute

        self.utype = Attribute.Attribute().TileAttribute
        self.optFcns = (a.TileAttributeStart, a.TileAttributeEnd)

        self.intvecs.append((a.TileAttributeAddMultiples, multiples))

    def ResizeAttribute(self, output_size, stride, offset, shift, mode):
        '''Select ResizeAttribute: three int-list fields plus integer shift and mode.'''
        from tosa import ResizeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ResizeAttribute
        self.optFcns = (a.ResizeAttributeStart, a.ResizeAttributeEnd)

        self.intvecs.append((a.ResizeAttributeAddOutputSize, output_size))
        self.intvecs.append((a.ResizeAttributeAddStride, stride))
        self.intvecs.append((a.ResizeAttributeAddOffset, offset))
        self.ints.append((a.ResizeAttributeAddShift, shift))
        # mode is passed through as an integer (presumably a tosa.ResizeMode value).
        self.ints.append((a.ResizeAttributeAddMode, mode))

    def ClampAttribute(self, minint, maxint, minfp, maxfp):
        '''Select ClampAttribute with integer and floating-point bounds.'''
        from tosa import ClampAttribute as a, Attribute

        self.utype = Attribute.Attribute().ClampAttribute
        self.optFcns = (a.ClampAttributeStart, a.ClampAttributeEnd)

        self.ints.append((a.ClampAttributeAddMinInt, minint))
        self.ints.append((a.ClampAttributeAddMaxInt, maxint))

        # The *fp bounds are floating-point values, so they are filed under
        # floats (previously under ints). serialize() still adds them after
        # the int bounds, in the same order as before.
        self.floats.append((a.ClampAttributeAddMinFp, minfp))
        self.floats.append((a.ClampAttributeAddMaxFp, maxfp))

    def RescaleAttribute(self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel):
        '''Select RescaleAttribute: int zero points, int-list multiplier/shift, three bool flags.'''
        from tosa import RescaleAttribute as a, Attribute

        self.utype = Attribute.Attribute().RescaleAttribute
        self.optFcns = (a.RescaleAttributeStart, a.RescaleAttributeEnd)

        self.ints.append((a.RescaleAttributeAddInputZp, input_zp))
        self.ints.append((a.RescaleAttributeAddOutputZp, output_zp))
        self.intvecs.append((a.RescaleAttributeAddMultiplier, multiplier))
        self.intvecs.append((a.RescaleAttributeAddShift, shift))
        self.bools.append((a.RescaleAttributeAddScale32, scale32))
        self.bools.append((a.RescaleAttributeAddDoubleRound, double_round))
        self.bools.append((a.RescaleAttributeAddPerChannel, per_channel))

    def MulAttribute(self, shift):
        '''Select MulAttribute with an integer shift.'''
        from tosa import MulAttribute as a, Attribute

        self.utype = Attribute.Attribute().MulAttribute
        self.optFcns = (a.MulAttributeStart, a.MulAttributeEnd)

        self.ints.append((a.MulAttributeAddShift, shift))

    def ArithmeticRightShiftAttribute(self, round):
        '''Select ArithmeticRightShiftAttribute with a boolean round flag.'''
        from tosa import ArithmeticRightShiftAttribute as a, Attribute

        self.utype = Attribute.Attribute().ArithmeticRightShiftAttribute
        self.optFcns = (a.ArithmeticRightShiftAttributeStart, a.ArithmeticRightShiftAttributeEnd)

        # The parameter name shadows the builtin round(); kept because it is
        # part of the public keyword interface.
        self.bools.append((a.ArithmeticRightShiftAttributeAddRound, round))

    def CustomAttribute(self, identifier):
        '''Select CustomAttribute with a string identifier.'''
        from tosa import CustomAttribute as a, Attribute

        self.utype = Attribute.Attribute().CustomAttribute
        self.optFcns = (a.CustomAttributeStart, a.CustomAttributeEnd)

        self.strings.append((a.CustomAttributeAddIdentifier, identifier))

    def CondIfAttribute(self, then_branch, else_branch):
        '''Select CondIfAttribute; the branches are given as (basic block) name strings.'''
        from tosa import CondIfAttribute as a, Attribute

        self.utype = Attribute.Attribute().CondIfAttribute
        self.optFcns = (a.CondIfAttributeStart, a.CondIfAttributeEnd)

        self.strings.append((a.CondIfAttributeAddThenBranch, then_branch))
        self.strings.append((a.CondIfAttributeAddElseBranch, else_branch))

    def WhileLoopAttribute(self, cond_branch, body_branch):
        '''Select WhileLoopAttribute; the branches are given as (basic block) name strings.'''
        from tosa import WhileLoopAttribute as a, Attribute

        self.utype = Attribute.Attribute().WhileLoopAttribute
        self.optFcns = (a.WhileLoopAttributeStart, a.WhileLoopAttributeEnd)

        self.strings.append((a.WhileLoopAttributeAddCondBranch, cond_branch))
        self.strings.append((a.WhileLoopAttributeAddBodyBranch, body_branch))
298
class TosaSerializerQuantInfo(TosaSerializerUnion):
    '''Encapsulates one member of the tosa QuantInfo union.

    Each method below picks the union member (utype/optFcns) and records its
    integer *_zp fields; the inherited serialize() writes the table.
    '''

    def __init__(self):
        super().__init__()

    def ConvQuantInfo(self, input_zp, weight_zp):
        '''Select ConvQuantInfo with input and weight zero points.'''
        from tosa import ConvQuantInfo as q, QuantInfo

        self.optFcns = (q.ConvQuantInfoStart, q.ConvQuantInfoEnd)
        self.utype = QuantInfo.QuantInfo().ConvQuantInfo
        self.ints.extend([(q.ConvQuantInfoAddInputZp, input_zp),
                          (q.ConvQuantInfoAddWeightZp, weight_zp)])

    def UnaryQuantInfo(self, input_zp, output_zp):
        '''Select UnaryQuantInfo with input and output zero points.'''
        from tosa import UnaryQuantInfo as q, QuantInfo

        self.optFcns = (q.UnaryQuantInfoStart, q.UnaryQuantInfoEnd)
        self.utype = QuantInfo.QuantInfo().UnaryQuantInfo
        self.ints.extend([(q.UnaryQuantInfoAddInputZp, input_zp),
                          (q.UnaryQuantInfoAddOutputZp, output_zp)])

    def MatMulQuantInfo(self, a_zp, b_zp):
        '''Select MatMulQuantInfo with zero points for the A and B operands.'''
        from tosa import MatMulQuantInfo as q, QuantInfo

        self.optFcns = (q.MatMulQuantInfoStart, q.MatMulQuantInfoEnd)
        self.utype = QuantInfo.QuantInfo().MatMulQuantInfo
        self.ints.extend([(q.MatMulQuantInfoAddAZp, a_zp),
                          (q.MatMulQuantInfoAddBZp, b_zp)])

    def PadQuantInfo(self, input_zp):
        '''Select PadQuantInfo with the input zero point.'''
        from tosa import PadQuantInfo as q, QuantInfo

        self.optFcns = (q.PadQuantInfoStart, q.PadQuantInfoEnd)
        self.utype = QuantInfo.QuantInfo().PadQuantInfo
        self.ints.extend([(q.PadQuantInfoAddInputZp, input_zp)])
334
class TosaSerializerTensor:
    '''Metadata for one tensor, serialized as a TosaTensor flatbuffer table.'''

    def __init__(self, name, shape, dtype, usage, dformat, filename = None, placeholderFilename = None):
        self.name = name

        # Normalize the shape to a plain list of Python ints (numpy arrays and
        # numpy integer scalars are converted) before it is stored/serialized.
        if isinstance(shape, np.ndarray):
            shape = shape.astype(int).tolist()
        self.shape = list(map(int, shape))
        self.dtype = dtype

        # Always hold private lists: a list passed in by the caller (e.g. the
        # usage list of another tensor) is copied, so later addUsage(),
        # addFormat() or merge() calls cannot mutate the caller's list.
        self.usage = self._to_list(usage)
        self.dformat = self._to_list(dformat)

        # Filename for const tensors. This gets written to the .tosa serialization
        self.filename = filename

        # Filename for placeholder tensors. These get generated by the test generation
        # process and are written to disk, but are considered input tensors by the network
        # so they do not appear in the TOSA serialization. However, if we want to form a unit
        # test around these input tensors, we can get the filename from here.
        self.placeholderFilename = placeholderFilename

    @staticmethod
    def _to_list(val):
        '''Return val as a new list: lists are copied, anything else is wrapped.'''
        return list(val) if isinstance(val, list) else [val]

    def __str__(self):
        return 'TosaSerializerTensor name: {} shape: {} dtype: {} Usage: {} format {} filename: {}'.format(
            self.name, self.shape, DTypeNames[self.dtype], self.usage, self.dformat, self.filename)

    def addUsage(self, usage):
        self.usage.append(usage)

    def addFormat(self, format):
        self.dformat.append(format)

    def setDtype(self, dtype):
        self.dtype = dtype

    def merge(self, name, shape, dtype, usage, dformat, filename = None):
        '''Add any usages/formats not already recorded on this tensor.

        usage and dformat may each be a single value or a list of values; each
        element is added only if it is not already present. name, shape, dtype
        and filename are accepted for call compatibility but ignored.
        '''
        for u in self._to_list(usage):
            if u not in self.usage:
                self.usage.append(u)

        for f in self._to_list(dformat):
            if f not in self.dformat:
                self.dformat.append(f)

    def serialize(self, builder):
        '''Write this tensor as a TosaTensor table and return its offset.'''
        # Strings/vectors must be created before the table is started.
        fb_name = builder.CreateString(self.name)
        if self.filename:
            fb_filename = builder.CreateString(self.filename)
        fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape)
        fb_usage = TosaSerializer.serializeInt32Vec(builder, self.usage)
        fb_dformat = TosaSerializer.serializeInt32Vec(builder, self.dformat)

        TosaTensor.TosaTensorStart(builder)
        TosaTensor.TosaTensorAddName(builder, fb_name)
        TosaTensor.TosaTensorAddShape(builder, fb_shapes)
        TosaTensor.TosaTensorAddType(builder, self.dtype)
        TosaTensor.TosaTensorAddUsage(builder, fb_usage)
        TosaTensor.TosaTensorAddFormat(builder, fb_dformat)
        if self.filename:
            TosaTensor.TosaTensorAddNpyFilename(builder, fb_filename)

        return TosaTensor.TosaTensorEnd(builder)
407
class TosaSerializerOperator:
    '''One operator node: op code, input/output tensor names, optional attribute and quant info.'''

    def __init__(self, op, inputs, outputs, attributes = None, quantInfo = None):
        self.op = op
        self.attributes = attributes
        # Single names are wrapped into one-element lists.
        self.inputs = TosaSerializer.toList(inputs)
        self.outputs = TosaSerializer.toList(outputs)
        self.quantInfo = quantInfo

    def __str__(self):
        pieces = ['Op {}\n----\n'.format(self.op)]
        pieces.extend(' Input: {}\n'.format(tensor_name) for tensor_name in self.inputs)
        pieces.extend(' Output: {}\n'.format(tensor_name) for tensor_name in self.outputs)
        return ''.join(pieces)

    def serialize(self, builder):
        '''Write this operator as a TosaOperator table and return its offset.'''
        fb_inputs = TosaSerializer.serializeStrVec(builder, self.inputs, TosaOperator.TosaOperatorStartInputsVector)
        fb_outputs = TosaSerializer.serializeStrVec(builder, self.outputs, TosaOperator.TosaOperatorStartOutputsVector)

        has_attributes = self.attributes is not None
        has_qinfo = self.quantInfo is not None

        # The union tables are nested objects, so they must be finished before
        # the operator table is started.
        fb_attributes = self.attributes.serialize(builder) if has_attributes else None
        fb_qinfo = self.quantInfo.serialize(builder) if has_qinfo else None

        TosaOperator.TosaOperatorStart(builder)
        TosaOperator.TosaOperatorAddOp(builder, self.op)
        TosaOperator.TosaOperatorAddInputs(builder, fb_inputs)
        TosaOperator.TosaOperatorAddOutputs(builder, fb_outputs)
        if has_attributes:
            TosaOperator.TosaOperatorAddAttributeType(builder, self.attributes.utype)
            TosaOperator.TosaOperatorAddAttribute(builder, fb_attributes)
        if has_qinfo:
            TosaOperator.TosaOperatorAddQuantInfoType(builder, self.quantInfo.utype)
            TosaOperator.TosaOperatorAddQuantInfo(builder, fb_qinfo)

        return TosaOperator.TosaOperatorEnd(builder)
448
class TosaSerializerBasicBlock:
    '''A named basic block: its tensors, operators, and input/output tensor names.'''

    def __init__(self, name):
        self.name = name
        self.operators = []

        # Dict assures uniqueness, but allows us to look up by name
        self.tensors = dict()

        self.inputs = []
        self.outputs = []

    def __str__(self):
        header = 'BasicBlock {}\n Inputs: {}\n Outputs: {}\n'.format(
            self.name, self.inputs, self.outputs)
        return header + ''.join(str(op) for op in self.operators)

    def addTensor(self, name, shape, dtype, usage, dformat, filename = None, placeholderFilename = None):
        '''Add a tensor, or merge usages/formats into an existing one of the same name.

        Returns the tensor object stored under name.
        '''
        tens = self.tensors.get(name)
        if tens is None:
            tens = TosaSerializerTensor(name, shape, dtype, usage, dformat, filename, placeholderFilename)
            self.tensors[name] = tens
        else:
            # Someone already added this tensor; we may have to add more
            # usages and formats. Done outside any try/except so an error
            # raised by merge() is not mistaken for "tensor not found".
            tens.merge(name, shape, dtype, usage, dformat, filename)
        return tens

    def addInput(self, name):
        self.inputs.append(name)

    def addOutput(self, name):
        self.outputs.append(name)

    def addOperator(self, op, inputs, outputs, attributes = None, quant_info = None):
        self.operators.append(TosaSerializerOperator(op, inputs, outputs, attributes, quant_info))

    def serialize(self, builder):
        '''Write this block as a TosaBasicBlock table and return its offset.'''
        fb_name = builder.CreateString(self.name)
        fbv_inputs = TosaSerializer.serializeStrVec(builder, list(self.inputs), TosaBasicBlock.TosaBasicBlockStartInputsVector)
        fbv_outputs = TosaSerializer.serializeStrVec(builder, list(self.outputs), TosaBasicBlock.TosaBasicBlockStartOutputsVector)
        fbv_tensors = TosaSerializer.serializeObjVec(builder, list(self.tensors.values()), TosaBasicBlock.TosaBasicBlockStartTensorsVector)
        fbv_operators = TosaSerializer.serializeObjVec(builder, self.operators, TosaBasicBlock.TosaBasicBlockStartOperatorsVector)

        TosaBasicBlock.TosaBasicBlockStart(builder)
        TosaBasicBlock.TosaBasicBlockAddName(builder, fb_name)
        TosaBasicBlock.TosaBasicBlockAddInputs(builder, fbv_inputs)
        TosaBasicBlock.TosaBasicBlockAddOutputs(builder, fbv_outputs)
        TosaBasicBlock.TosaBasicBlockAddTensors(builder, fbv_tensors)
        TosaBasicBlock.TosaBasicBlockAddOperators(builder, fbv_operators)
        return TosaBasicBlock.TosaBasicBlockEnd(builder)
494
@unique
class TensorDir(IntEnum):
    '''Tensor role/direction categories.

    NOTE(review): not referenced elsewhere in this file; presumably used by
    test generators to classify tensors -- confirm against callers.
    '''
    PLACEHOLDER = 0
    CONST = 1
    INTERMEDIATE = 2
    RESULT = 3
501
class TosaSerializer:
    '''Builds a TOSA flatbuffer graph plus its .npy data files and a JSON test descriptor.

    Tensors and operators are added to the current basic block (initially
    'main'). Placeholder/const values are saved under pathPrefix as they are
    added; serialize() returns the flatbuffer bytes and writeJson() returns
    the JSON descriptor text.
    '''

    def __init__(self, pathPrefix):

        # Get the global TOSA version if not already defined
        try:
            TOSA_VERSION
        except NameError:
            TosaSerializer.setTosaVersion()

        self.builder = flatbuffers.Builder(0)

        self.basicBlocks = []
        self.startBasicBlock('main')
        self.pathPrefix = pathPrefix

        # Indices used for adding/naming tensors.
        # NOTE(review): addConst() numbers const tensors from currInputIdx, so
        # placeholders and consts share one counter and currConstIdx is never
        # read. Left as-is because the generated tensor names are visible in
        # the .tosa output and the JSON descriptor.
        self.currInputIdx = 0
        self.currConstIdx = 0
        self.currLayerIdx = 1
        self.currResultIdx = 0

        # Is this an illegal test that is expected to fail?
        self.expectedFailure = False
        self.expectedFailureDesc = ''

    def __str__(self):
        return ''.join(str(bb) for bb in self.basicBlocks)

    def addPlaceholder(self, shape, dtype, usage, dformat, vals):
        '''Add graph input tensor input-N (with a PLACEHOLDER op) to the current block.

        If vals is not None it is saved to <pathPrefix>/input-N.npy.
        Returns the tensor object.
        '''
        if not self.currBasicBlock:
            raise Exception('addTensor called without valid basic block')

        name = 'input-{}'.format(self.currInputIdx)
        filename = '{}.npy'.format(name)
        self.currInputIdx += 1

        # filename=None keeps the .npy name out of the serialized tensor; it is
        # recorded as placeholderFilename so writeJson() can report it.
        tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, None, filename)
        # This is always an input to the block
        self.currBasicBlock.addInput(name)
        # Add the operator now
        self.currBasicBlock.addOperator(tosa.Op.Op().PLACEHOLDER, [], name)

        if vals is not None:
            np.save(os.path.join(self.pathPrefix, filename), vals, allow_pickle=False)

        return tens

    def addConst(self, shape, dtype, usage, dformat, vals):
        '''Add constant tensor const-N (with a CONST op) to the current block.

        If vals is not None it is saved to <pathPrefix>/const-N.npy.
        Returns the tensor object.
        '''
        if not self.currBasicBlock:
            raise Exception('addTensor called without valid basic block')

        # Shares currInputIdx with addPlaceholder(); see the note in __init__.
        name = 'const-{}'.format(self.currInputIdx)
        filename = '{}.npy'.format(name)
        self.currInputIdx += 1

        # Unlike placeholders, the .npy filename is serialized with the tensor.
        tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, filename)
        # Add the operator now
        self.currBasicBlock.addOperator(tosa.Op.Op().CONST, [], name)

        if vals is not None:
            np.save(os.path.join(self.pathPrefix, filename), vals, allow_pickle=False)
        return tens

    def addIntermediate(self, shape, dtype, usage, dformat):
        '''Add intermediate tensor layer-N (no backing file) and return it.'''
        if not self.currBasicBlock:
            raise Exception('addTensor called without valid basic block')

        name = 'layer-{}'.format(self.currLayerIdx)
        self.currLayerIdx += 1

        return self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, None)

    def addInputTensor(self, tensor):
        '''Register an existing tensor object as a PLACEHOLDER input of the current block.'''
        self.currBasicBlock.addOperator(tosa.Op.Op().PLACEHOLDER, [], tensor.name)
        self.currBasicBlock.addTensor(tensor.name, tensor.shape, tensor.dtype, tensor.usage, tensor.dformat)
        self.currBasicBlock.addInput(tensor.name)

    def addOutputTensor(self, tensor):
        '''Mark an existing tensor object as an output of the current block.'''
        self.currBasicBlock.addOutput(tensor.name)

    def addOutput(self, shape, dtype, usage, dformat):
        '''Add output tensor result-N to the current block and return it.'''
        if not self.currBasicBlock:
            raise Exception('addTensor called without valid basic block')

        name = 'result-{}'.format(self.currResultIdx)
        self.currResultIdx += 1

        tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, None)
        self.currBasicBlock.addOutput(name)
        return tens

    def addOperator(self, op, inputs, outputs, attributes = None, quant_info = None):
        '''Add a non-PLACEHOLDER, non-CONST operator to the current block.

        Raises Exception for PLACEHOLDER/CONST, which must be added through
        addPlaceholder()/addConst() so their tensors are created as well.
        '''
        if op == tosa.Op.Op().PLACEHOLDER or op == tosa.Op.Op().CONST:
            raise Exception('Use addPlaceholder() or addConst() to add PLACEHOLDER and CONST ops')

        return self.currBasicBlock.addOperator(op, inputs, outputs, attributes, quant_info)

    def setExpectedFailure(self, desc='', val=True):
        '''Mark (or, with val=False, unmark) this test as expected to fail.'''
        self.expectedFailure = val
        self.expectedFailureDesc = desc

    def serialize(self):
        '''Finish the TosaGraph (version + all basic blocks) and return the buffer.

        NOTE(review): this reuses self.builder, so calling it more than once
        writes a second graph into the same builder -- confirm whether
        repeated calls are expected.
        '''
        builder = self.builder

        Version.VersionStart(builder)
        Version.VersionAdd_major(builder, TOSA_VERSION[0])
        Version.VersionAdd_minor(builder, TOSA_VERSION[1])
        Version.VersionAdd_patch(builder, TOSA_VERSION[2])
        Version.VersionAdd_experimental(builder, TOSA_VERSION[3])
        version = Version.VersionEnd(builder)

        fbv_bb = TosaSerializer.serializeObjVec(builder, self.basicBlocks, TosaGraph.TosaGraphStartBlocksVector)

        TosaGraph.TosaGraphStart(builder)
        TosaGraph.TosaGraphAddVersion(builder, version)
        TosaGraph.TosaGraphAddBlocks(builder, fbv_bb)
        graph = TosaGraph.TosaGraphEnd(builder)

        self.builder.Finish(graph)
        return self.builder.Output()

    def writeJson(self, tosa_filename):
        '''Write a json test file so that it is fairly easy to pick up the test
        and generate commands for third party tool'''
        ifm_name = []
        ifm_shape = []
        ifm_file = []
        ofm_name = []
        ofm_file = []
        ofm_shape = []

        # Only the 'main' block's inputs/outputs describe the test interface.
        for b in self.basicBlocks:
            if b.name != 'main':
                continue
            for i in b.inputs:
                ifm_name.append(i)
                ifm_shape.append(b.tensors[i].shape)
                ifm_file.append(b.tensors[i].placeholderFilename)
            for o in b.outputs:
                ofm_name.append(o)
                ofm_shape.append(b.tensors[o].shape)
                # Make up an OFM filename here. One isn't generated until the reference tool is
                # run, so any name is a good name
                ofm_file.append('ref-{}.npy'.format(o))

        test_desc = {
            'tosa_file': tosa_filename,
            'ifm_placeholder': ifm_name,
            'ifm_file': ifm_file,
            'ifm_shape': ifm_shape,
            'ofm_name': ofm_name,
            'ofm_shape': ofm_shape,
            'ofm_file': ofm_file,
            'expected_failure': self.expectedFailure,
        }
        if self.expectedFailureDesc:
            test_desc['expected_failure_desc'] = self.expectedFailureDesc

        return json.dumps(test_desc, indent=' ')

    def startBasicBlock(self, name):
        '''Create a new basic block, append it, and make it the current block.'''
        self.currBasicBlock = TosaSerializerBasicBlock(name)
        self.basicBlocks.append(self.currBasicBlock)

    @staticmethod
    def serializeStrVec(builder, vec, start_fcn):
        '''Serialize a list of strings as a vector of string offsets.'''
        fb_strs = [builder.CreateString(i) for i in vec]
        start_fcn(builder, len(fb_strs))
        # Flatbuffers vectors are built back to front.
        for s in fb_strs[::-1]:
            builder.PrependUOffsetTRelative(s)
        return builder.EndVector(len(fb_strs))

    @staticmethod
    def serializeInt32Vec(builder, vec):
        '''Serialize a list of ints as an int32 vector (4-byte elements and alignment).'''
        builder.StartVector(4, len(vec), 4)
        for v in vec[::-1]:
            builder.PrependInt32(v)
        return builder.EndVector(len(vec))

    @staticmethod
    def serializeObjVec(builder, vec, start_fcn):
        '''Serialize objects exposing serialize(builder) as a vector of table offsets.'''
        # Objects are serialized last-to-first; prepending the offsets in that
        # same order leaves the vector in the original order.
        serialized_vec = [v.serialize(builder) for v in vec[::-1]]

        start_fcn(builder, len(vec))
        for v in serialized_vec:
            builder.PrependUOffsetTRelative(v)
        return builder.EndVector(len(vec))

    @staticmethod
    def toList(val):
        '''Return val unchanged if it is a list, otherwise [val].'''
        if isinstance(val, list):
            return val
        else:
            return [val]

    @staticmethod
    def setTosaVersion():
        '''Read the schema's default Version values into the global TOSA_VERSION.'''
        # Create a dummy flatbuffers file with the default version information
        # There does not appear to be a better way to get a constant from a
        # flatbuffer schema file
        builder = flatbuffers.Builder(0)
        Version.VersionStart(builder)
        ver = Version.VersionEnd(builder)
        TosaGraph.TosaGraphStart(builder)
        TosaGraph.TosaGraphAddVersion(builder, ver)
        gr = TosaGraph.TosaGraphEnd(builder)
        builder.Finish(gr)

        out = builder.Output()

        gr = TosaGraph.TosaGraph()
        root = gr.GetRootAsTosaGraph(out, 0)

        # Store the version as a global variable so that it only needs to be
        # generated once per process.
        global TOSA_VERSION
        TOSA_VERSION = [root.Version()._major(),
                        root.Version()._minor(),
                        root.Version()._patch(),
                        root.Version()._experimental()]