blob: 7ba68c3715a86ef1d02ae957ced2e818432c79e4 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
2
3# Copyright (c) 2020, ARM Limited.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17#!/usr/bin/env python3
18
19import flatbuffers
20import numpy as np
21from enum import Enum, IntEnum, unique
22from tosa import TosaGraph, TosaBasicBlock, TosaTensor, TosaOperator, DType, Format, Usage, Op, ResizeMode, Version
23import tosa
24import os
25import json
26
# With the way flatc generates its Python types, there is no programmatic way
# to get string names for the integer types. Manually maintain a string table
# here.
# String table for the DType integer values: the position of a name in this
# list is the integer it stands for (the list is indexed directly by dtype).
DTypeNames = [
    'UNKNOWN',  # 0
    'BOOL',     # 1
    'AINT8',    # 2
    'UINT8',    # 3
    'INT4',     # 4
    'INT8',     # 5
    'INT16',    # 6
    'INT32',    # 7
    'INT48',    # 8
    'FLOAT',    # 9
]
40
def dtype_str_to_val(name):
    '''Return the integer DType value for a type name.

    The comparison is case-insensitive; the returned value is the index of
    the matching entry in DTypeNames.

    Raises:
        Exception: if name does not match any entry in DTypeNames.
    '''
    # Fold the requested name once instead of on every iteration.
    wanted = name.casefold()
    for value, candidate in enumerate(DTypeNames):
        if candidate.casefold() == wanted:
            return value
    raise Exception('Unable to parse DType name {}'.format(name))
47
48
class TosaSerializerUnion:
    '''Collects the fields of one flatbuffers union member and serializes them.'''

    def __init__(self):
        # Pending fields, grouped by kind. Every entry is an
        # (add_function, value) pair queued by an options constructor.
        self.ints = []
        self.bools = []
        self.floats = []
        self.strings = []
        self.intvecs = []

        # Union type tag, assigned by an options constructor.
        self.utype = None

        # (start_function, end_function) pair, assigned by an options constructor.
        self.optFcns = None

    def serialize(self, builder):
        '''Write the queued fields into builder and return the end function's result.'''

        # Strings and vectors have to be built before the object is started.
        built_strings = [(add, builder.CreateString(text))
                         for add, text in self.strings]
        built_vectors = [(add, TosaSerializer.serializeInt32Vec(builder, values))
                         for add, values in self.intvecs]

        begin, finish = self.optFcns

        # Emit the primitives first, then the pre-built strings and vectors.
        begin(builder)
        for group in (self.ints, self.bools, self.floats, built_strings, built_vectors):
            for add, value in group:
                add(builder, value)

        return finish(builder)
100
class TosaSerializerAttribute(TosaSerializerUnion):
    '''This class handles encapsulating all of the enumerated types for attributes.

    Each method configures this object as one attribute kind: it sets the
    union type tag (utype), the start/end builder functions (optFcns) and
    queues (add_function, value) pairs that the inherited serialize() emits.
    '''

    def __init__(self):
        super().__init__()

    def Pool2dAttribute(self, kernel, stride, padding):
        '''Configure as a Pool2dAttribute with padding, kernel and stride vectors.'''
        from tosa import Pool2dAttribute as a, Attribute

        self.utype = Attribute.Attribute().Pool2dAttribute

        self.optFcns = (a.Pool2dAttributeStart, a.Pool2dAttributeEnd)
        self.intvecs.append((a.Pool2dAttributeAddPadding, padding))
        self.intvecs.append((a.Pool2dAttributeAddKernel, kernel))
        self.intvecs.append((a.Pool2dAttributeAddStride, stride))

    def Conv2dAttribute(self, padding, stride, dilation):
        '''Configure as a Conv2dAttribute with padding, stride and dilation vectors.'''
        from tosa import Conv2dAttribute as a, Attribute

        self.utype = Attribute.Attribute().Conv2dAttribute
        self.optFcns = (a.Conv2dAttributeStart, a.Conv2dAttributeEnd)

        self.intvecs.append((a.Conv2dAttributeAddPadding, padding))
        self.intvecs.append((a.Conv2dAttributeAddStride, stride))
        self.intvecs.append((a.Conv2dAttributeAddDilation, dilation))

    def TransposeConv2DAttribute(self, outpad, stride, dilation, output_shape):
        '''Configure as a TransposeConv2dAttribute with outpad, stride, dilation and output shape vectors.'''
        from tosa import TransposeConv2dAttribute as a, Attribute

        self.utype = Attribute.Attribute().TransposeConv2dAttribute
        self.optFcns = (a.TransposeConv2dAttributeStart, a.TransposeConv2dAttributeEnd)

        self.intvecs.append((a.TransposeConv2dAttributeAddOutpad, outpad))
        self.intvecs.append((a.TransposeConv2dAttributeAddStride, stride))
        self.intvecs.append((a.TransposeConv2dAttributeAddDilation, dilation))
        self.intvecs.append((a.TransposeConv2dAttributeAddOutputShape, output_shape))

    def ReluNAttribute(self, maxint, maxfp):
        '''Configure as a ReluNAttribute with an integer and an fp maximum.'''
        from tosa import ReluNAttribute as a, Attribute

        self.utype = Attribute.Attribute().ReluNAttribute
        self.optFcns = (a.ReluNAttributeStart, a.ReluNAttributeEnd)

        self.ints.append((a.ReluNAttributeAddMaxInt, maxint))
        # The fp value is queued in floats rather than ints. The emission
        # order is unchanged: serialize() writes ints, then bools, then
        # floats, and no bools are queued here.
        self.floats.append((a.ReluNAttributeAddMaxFp, maxfp))

    def AxisAttribute(self, axis):
        '''Configure as an AxisAttribute with a single axis value.'''
        from tosa import AxisAttribute as a, Attribute

        self.utype = Attribute.Attribute().AxisAttribute
        self.optFcns = (a.AxisAttributeStart, a.AxisAttributeEnd)

        self.ints.append((a.AxisAttributeAddAxis, axis))

    def ReshapeAttribute(self, shape):
        '''Configure as a ReshapeAttribute with a shape vector.'''
        from tosa import ReshapeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ReshapeAttribute
        self.optFcns = (a.ReshapeAttributeStart, a.ReshapeAttributeEnd)

        self.intvecs.append((a.ReshapeAttributeAddShape, shape))

    def SliceAttribute(self, begin, size):
        '''Configure as a SliceAttribute with begin and size vectors.'''
        from tosa import SliceAttribute as a, Attribute

        self.utype = Attribute.Attribute().SliceAttribute
        self.optFcns = (a.SliceAttributeStart, a.SliceAttributeEnd)

        self.intvecs.append((a.SliceAttributeAddBegin, begin))
        self.intvecs.append((a.SliceAttributeAddSize, size))

    def TileAttribute(self, multiples):
        '''Configure as a TileAttribute with a multiples vector.'''
        from tosa import TileAttribute as a, Attribute

        self.utype = Attribute.Attribute().TileAttribute
        self.optFcns = (a.TileAttributeStart, a.TileAttributeEnd)

        self.intvecs.append((a.TileAttributeAddMultiples, multiples))

    def ResizeAttribute(self, output_size, stride, offset, shift, mode):
        '''Configure as a ResizeAttribute.

        output_size, stride and offset are queued as int vectors; shift and
        mode are queued as scalars.
        '''
        from tosa import ResizeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ResizeAttribute
        self.optFcns = (a.ResizeAttributeStart, a.ResizeAttributeEnd)

        self.intvecs.append((a.ResizeAttributeAddOutputSize, output_size))
        self.intvecs.append((a.ResizeAttributeAddStride, stride))
        self.intvecs.append((a.ResizeAttributeAddOffset, offset))
        self.ints.append((a.ResizeAttributeAddShift, shift))
        self.ints.append((a.ResizeAttributeAddMode, mode))

    def ClampAttribute(self, minint, maxint, minfp, maxfp):
        '''Configure as a ClampAttribute with integer and fp min/max bounds.'''
        from tosa import ClampAttribute as a, Attribute

        self.utype = Attribute.Attribute().ClampAttribute
        self.optFcns = (a.ClampAttributeStart, a.ClampAttributeEnd)

        self.ints.append((a.ClampAttributeAddMinInt, minint))
        self.ints.append((a.ClampAttributeAddMaxInt, maxint))

        # The fp bounds are queued in floats rather than ints. The emission
        # order is unchanged: serialize() writes ints, then bools, then
        # floats, and no bools are queued here.
        self.floats.append((a.ClampAttributeAddMinFp, minfp))
        self.floats.append((a.ClampAttributeAddMaxFp, maxfp))

    def RescaleAttribute(self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel):
        '''Configure as a RescaleAttribute.

        input_zp and output_zp are queued as ints, multiplier and shift as
        int vectors, and scale32, double_round and per_channel as bools.
        '''
        from tosa import RescaleAttribute as a, Attribute

        self.utype = Attribute.Attribute().RescaleAttribute
        self.optFcns = (a.RescaleAttributeStart, a.RescaleAttributeEnd)

        self.ints.append((a.RescaleAttributeAddInputZp, input_zp))
        self.ints.append((a.RescaleAttributeAddOutputZp, output_zp))
        self.intvecs.append((a.RescaleAttributeAddMultiplier, multiplier))
        self.intvecs.append((a.RescaleAttributeAddShift, shift))
        self.bools.append((a.RescaleAttributeAddScale32, scale32))
        self.bools.append((a.RescaleAttributeAddDoubleRound, double_round))
        self.bools.append((a.RescaleAttributeAddPerChannel, per_channel))

    def CustomAttribute(self, identifier):
        '''Configure as a CustomAttribute with an identifier string.'''
        from tosa import CustomAttribute as a, Attribute

        self.utype = Attribute.Attribute().CustomAttribute
        self.optFcns = (a.CustomAttributeStart, a.CustomAttributeEnd)

        self.strings.append((a.CustomAttributeAddIdentifier, identifier))

    def CondIfAttribute(self, then_branch, else_branch):
        '''Configure as a CondIfAttribute with then/else branch strings.'''
        from tosa import CondIfAttribute as a, Attribute

        self.utype = Attribute.Attribute().CondIfAttribute
        self.optFcns = (a.CondIfAttributeStart, a.CondIfAttributeEnd)

        self.strings.append((a.CondIfAttributeAddThenBranch, then_branch))
        self.strings.append((a.CondIfAttributeAddElseBranch, else_branch))

    def WhileLoopAttribute(self, cond_branch, body_branch):
        '''Configure as a WhileLoopAttribute with cond/body branch strings.'''
        from tosa import WhileLoopAttribute as a, Attribute

        self.utype = Attribute.Attribute().WhileLoopAttribute
        self.optFcns = (a.WhileLoopAttributeStart, a.WhileLoopAttributeEnd)

        self.strings.append((a.WhileLoopAttributeAddCondBranch, cond_branch))
        self.strings.append((a.WhileLoopAttributeAddBodyBranch, body_branch))
280
class TosaSerializerQuantInfo(TosaSerializerUnion):
    '''Encapsulates the enumerated quantization-info types.

    Each method sets the union type tag and start/end functions for one
    QuantInfo kind and queues its zero-point values as integer fields.
    '''

    def __init__(self):
        super().__init__()

    def ConvQuantInfo(self, input_zp, weight_zp):
        '''Configure as a ConvQuantInfo with input and weight zero points.'''
        from tosa import ConvQuantInfo as q, QuantInfo

        self.optFcns = (q.ConvQuantInfoStart, q.ConvQuantInfoEnd)
        self.utype = QuantInfo.QuantInfo().ConvQuantInfo
        self.ints.extend([
            (q.ConvQuantInfoAddInputZp, input_zp),
            (q.ConvQuantInfoAddWeightZp, weight_zp),
        ])

    def UnaryQuantInfo(self, input_zp, output_zp):
        '''Configure as a UnaryQuantInfo with input and output zero points.'''
        from tosa import UnaryQuantInfo as q, QuantInfo

        self.optFcns = (q.UnaryQuantInfoStart, q.UnaryQuantInfoEnd)
        self.utype = QuantInfo.QuantInfo().UnaryQuantInfo
        self.ints.extend([
            (q.UnaryQuantInfoAddInputZp, input_zp),
            (q.UnaryQuantInfoAddOutputZp, output_zp),
        ])

    def MatMulQuantInfo(self, a_zp, b_zp):
        '''Configure as a MatMulQuantInfo with the zero points of operands a and b.'''
        from tosa import MatMulQuantInfo as q, QuantInfo

        self.optFcns = (q.MatMulQuantInfoStart, q.MatMulQuantInfoEnd)
        self.utype = QuantInfo.QuantInfo().MatMulQuantInfo
        self.ints.extend([
            (q.MatMulQuantInfoAddAZp, a_zp),
            (q.MatMulQuantInfoAddBZp, b_zp),
        ])

    def PadQuantInfo(self, input_zp):
        '''Configure as a PadQuantInfo with an input zero point.'''
        from tosa import PadQuantInfo as q, QuantInfo

        self.optFcns = (q.PadQuantInfoStart, q.PadQuantInfoEnd)
        self.utype = QuantInfo.QuantInfo().PadQuantInfo
        self.ints.extend([(q.PadQuantInfoAddInputZp, input_zp)])
316
class TosaSerializerTensor:
    '''Holds the description of one tensor and serializes it as a TosaTensor.'''

    def __init__(self, name, shape, dtype, usage, dformat, filename = None, placeholderFilename = None):
        '''Create a tensor description.

        shape may be a numpy array or any iterable of values convertible to
        int; it is stored as a list of Python ints. usage and dformat may
        each be a single value or a list of values.
        '''
        self.name = name

        if isinstance(shape, np.ndarray):
            shape = shape.astype(int).tolist()
        shape = list(map(int, shape))

        self.shape = shape
        self.dtype = dtype
        # Copy the lists: toList() hands back the caller's own list object,
        # and addUsage()/addFormat()/merge() mutate these in place. Without
        # the copy, tensors built from the same list would share state.
        self.usage = list(TosaSerializer.toList(usage))
        self.dformat = list(TosaSerializer.toList(dformat))

        # Filename for const tensors. This gets written to the .tosa serialization
        self.filename = filename

        # Filename for placeholder tensors. These get generated by the test generation
        # process and are written to disk, but are considered input tensors by the network
        # so they do not appear in the TOSA serialization. However, if we want to form a unit
        # test around these input tensors, we can get the filename from here.
        self.placeholderFilename = placeholderFilename

    def __str__(self):
        return 'TosaSerializerTensor name: {} shape: {} dtype: {} Usage: {} format {} filename: {}'.format(
            self.name, self.shape, DTypeNames[self.dtype], self.usage, self.dformat, self.filename)

    def addUsage(self, usage):
        '''Append usage to this tensor's usage list (no duplicate check).'''
        self.usage.append(usage)

    def addFormat(self, format):
        '''Append format to this tensor's format list (no duplicate check).'''
        self.dformat.append(format)

    def setDtype(self, dtype):
        '''Replace this tensor's dtype.'''
        self.dtype = dtype

    def merge(self, name, shape, dtype, usage, dformat, filename = None):
        '''Merge additional usages/formats into this tensor's lists.

        usage and dformat may each be a single value or a list of values;
        only values not already present are appended. name, shape, dtype and
        filename are accepted but not used. Returns None.
        '''
        for known, incoming in ((self.usage, usage), (self.dformat, dformat)):
            # A list argument is merged element by element; appending the
            # list itself would nest it inside the int vector to serialize.
            items = incoming if isinstance(incoming, list) else [incoming]
            for item in items:
                if item not in known:
                    known.append(item)

    def serialize(self, builder):
        '''Write this tensor into builder and return the TosaTensorEnd result.'''
        # Strings and vectors are created before the TosaTensor is started.
        fb_name = builder.CreateString(self.name)
        if self.filename:
            fb_filename = builder.CreateString(self.filename)
        fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape)
        fb_usage = TosaSerializer.serializeInt32Vec(builder, self.usage)
        fb_dformat = TosaSerializer.serializeInt32Vec(builder, self.dformat)

        TosaTensor.TosaTensorStart(builder)
        TosaTensor.TosaTensorAddName(builder, fb_name)
        TosaTensor.TosaTensorAddShape(builder, fb_shapes)
        TosaTensor.TosaTensorAddType(builder, self.dtype)
        TosaTensor.TosaTensorAddUsage(builder, fb_usage)
        TosaTensor.TosaTensorAddFormat(builder, fb_dformat)
        # The filename field is only written when a filename was given.
        if self.filename:
            TosaTensor.TosaTensorAddNpyFilename(builder, fb_filename)

        return TosaTensor.TosaTensorEnd(builder)
389
class TosaSerializerOperator:
    '''Holds one operator (op code, tensor names, attributes) and serializes it as a TosaOperator.'''

    def __init__(self, op, inputs, outputs, attributes = None, quantInfo = None):
        '''Create an operator.

        inputs and outputs may each be a single name or a list of names.
        attributes and quantInfo, when given, must provide serialize(builder)
        and a utype attribute.
        '''
        self.op = op
        self.attributes = attributes
        self.inputs = TosaSerializer.toList(inputs)
        self.outputs = TosaSerializer.toList(outputs)
        self.quantInfo = quantInfo

    def __str__(self):
        # Collect the pieces and join once rather than growing a string.
        parts = ['Op {}\n----\n'.format(self.op)]
        parts.extend(' Input: {}\n'.format(i) for i in self.inputs)
        parts.extend(' Output: {}\n'.format(o) for o in self.outputs)
        return ''.join(parts)

    def serialize(self, builder):
        '''Write this operator into builder and return the TosaOperatorEnd result.'''
        fb_inputs = TosaSerializer.serializeStrVec(builder, self.inputs, TosaOperator.TosaOperatorStartInputsVector)
        fb_outputs = TosaSerializer.serializeStrVec(builder, self.outputs, TosaOperator.TosaOperatorStartOutputsVector)

        # The attribute and quant-info objects are serialized before the
        # TosaOperator is started; only their results are added below.
        if self.attributes is not None:
            fb_attributes = self.attributes.serialize(builder)

        if self.quantInfo is not None:
            fb_qinfo = self.quantInfo.serialize(builder)

        TosaOperator.TosaOperatorStart(builder)
        TosaOperator.TosaOperatorAddOp(builder, self.op)
        TosaOperator.TosaOperatorAddInputs(builder, fb_inputs)
        TosaOperator.TosaOperatorAddOutputs(builder, fb_outputs)
        if self.attributes is not None:
            TosaOperator.TosaOperatorAddAttributeType(builder, self.attributes.utype)
            TosaOperator.TosaOperatorAddAttribute(builder, fb_attributes)
        if self.quantInfo is not None:
            TosaOperator.TosaOperatorAddQuantInfoType(builder, self.quantInfo.utype)
            TosaOperator.TosaOperatorAddQuantInfo(builder, fb_qinfo)

        return TosaOperator.TosaOperatorEnd(builder)
430
class TosaSerializerBasicBlock:
    '''A named block of operators and tensors, serialized as a TosaBasicBlock.'''

    def __init__(self, name):
        self.name = name
        self.operators = []

        # Dict assures uniqueness, but allows us to look up by name
        self.tensors = dict()

        self.inputs = []
        self.outputs = []

    def addTensor(self, name, shape, dtype, usage, dformat, filename = None, placeholderFilename = None):
        '''Add a tensor, or merge usage/format into an existing one of the same name.

        Returns the tensor object stored under name.
        '''
        tens = self.tensors.get(name)
        if tens is None:
            tens = TosaSerializerTensor(name, shape, dtype, usage, dformat, filename, placeholderFilename)
            self.tensors[name] = tens
        else:
            # Someone already added this tensor.
            # We may have to add more usages and formats
            tens.merge(name, shape, dtype, usage, dformat, filename)

        return tens

    def addInput(self, name):
        '''Record name as an input of this block.'''
        self.inputs.append(name)

    def addOutput(self, name):
        '''Record name as an output of this block.'''
        self.outputs.append(name)

    def addOperator(self, op, inputs, outputs, attributes = None, quant_info = None):
        '''Append a new operator to this block.'''
        self.operators.append(TosaSerializerOperator(op, inputs, outputs, attributes, quant_info))

    def serialize(self, builder):
        '''Write this block into builder and return the TosaBasicBlockEnd result.'''
        # The name and all vectors are built before the TosaBasicBlock is started.
        fb_name = builder.CreateString(self.name)
        fbv_inputs = TosaSerializer.serializeStrVec(builder, list(self.inputs), TosaBasicBlock.TosaBasicBlockStartInputsVector)
        fbv_outputs = TosaSerializer.serializeStrVec(builder, list(self.outputs), TosaBasicBlock.TosaBasicBlockStartOutputsVector)
        fbv_tensors = TosaSerializer.serializeObjVec(builder, list(self.tensors.values()), TosaBasicBlock.TosaBasicBlockStartTensorsVector)
        fbv_operators = TosaSerializer.serializeObjVec(builder, self.operators, TosaBasicBlock.TosaBasicBlockStartOperatorsVector)

        TosaBasicBlock.TosaBasicBlockStart(builder)
        TosaBasicBlock.TosaBasicBlockAddName(builder, fb_name)
        TosaBasicBlock.TosaBasicBlockAddInputs(builder, fbv_inputs)
        TosaBasicBlock.TosaBasicBlockAddOutputs(builder, fbv_outputs)
        TosaBasicBlock.TosaBasicBlockAddTensors(builder, fbv_tensors)
        TosaBasicBlock.TosaBasicBlockAddOperators(builder, fbv_operators)
        return TosaBasicBlock.TosaBasicBlockEnd(builder)
476
@unique
class TensorDir(IntEnum):
    '''Integer tags for the four tensor roles; @unique rejects duplicate values.'''
    PLACEHOLDER = 0
    CONST = 1
    INTERMEDIATE = 2
    RESULT = 3
483
class TosaSerializer:
    '''Builds a TOSA graph out of basic blocks and serializes it to a flatbuffer.

    Also provides the static helpers the other serializer classes use to
    build flatbuffers vectors.
    '''

    def __init__(self, pathPrefix):
        '''Create a serializer whose current basic block is a new 'main' block.

        pathPrefix is the directory that addPlaceholder() and addConst() join
        with the generated .npy file names when saving tensor values.
        '''

        # Get the global TOSA version if not already defined
        try:
            TOSA_VERSION
        except NameError:
            TosaSerializer.setTosaVersion()

        self.builder = flatbuffers.Builder(0)

        self.basicBlocks = []
        self.startBasicBlock('main')
        self.pathPrefix = pathPrefix

        # Indices used for adding/naming tensors
        self.currInputIdx = 0
        self.currConstIdx = 0
        self.currLayerIdx = 1
        self.currResultIdx = 0

        # Is this an illegal test that is expected to fail?
        self.expectedFailure = False
        self.expectedFailureDesc = ''

    def __str__(self):
        return ''.join(str(bb) for bb in self.basicBlocks)

    def addPlaceholder(self, shape, dtype, usage, dformat, vals):
        '''Add an 'input-N' placeholder tensor and its PLACEHOLDER operator.

        The tensor is registered as an input of the current block. When vals
        is not None it is saved to '<pathPrefix>/input-N.npy'. Returns the
        tensor object.

        Raises:
            Exception: if there is no current basic block.
        '''
        if not self.currBasicBlock:
            raise Exception('addTensor called without valid basic block')

        name = 'input-{}'.format(self.currInputIdx)
        filename = '{}.npy'.format(name)
        self.currInputIdx = self.currInputIdx + 1

        # The file name is passed as placeholderFilename, not filename.
        tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, None, filename)
        # This is always an input to the block
        self.currBasicBlock.addInput(name)
        # Add the operator now
        self.currBasicBlock.addOperator(tosa.Op.Op().PLACEHOLDER, [], name)

        if vals is not None:
            np.save(os.path.join(self.pathPrefix, filename), vals, allow_pickle=False)

        return tens

    def addConst(self, shape, dtype, usage, dformat, vals):
        '''Add a 'const-N' tensor and its CONST operator.

        When vals is not None it is saved to '<pathPrefix>/const-N.npy'.
        Returns the tensor object.

        Raises:
            Exception: if there is no current basic block.
        '''
        if not self.currBasicBlock:
            raise Exception('addTensor called without valid basic block')

        # NOTE(review): constants are numbered from currInputIdx, the counter
        # shared with addPlaceholder(); currConstIdx is never advanced. Kept
        # as-is because changing it would rename the generated tensors and
        # .npy files - confirm whether that is intended.
        name = 'const-{}'.format(self.currInputIdx)
        filename = '{}.npy'.format(name)
        self.currInputIdx = self.currInputIdx + 1

        tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, filename)
        # Add the operator now
        self.currBasicBlock.addOperator(tosa.Op.Op().CONST, [], name)

        if vals is not None:
            np.save(os.path.join(self.pathPrefix, filename), vals, allow_pickle=False)
        return tens

    def addIntermediate(self, shape, dtype, usage, dformat):
        '''Add a 'layer-N' tensor with no backing file and return it.

        Raises:
            Exception: if there is no current basic block.
        '''

        if not self.currBasicBlock:
            raise Exception('addTensor called without valid basic block')

        name = 'layer-{}'.format(self.currLayerIdx)
        filename = None  # No file, so no filename
        self.currLayerIdx = self.currLayerIdx + 1

        tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, filename)

        return tens

    def addInputTensor(self, tensor):
        '''Register an existing tensor description as a placeholder input of the current block.'''
        self.currBasicBlock.addOperator(tosa.Op.Op().PLACEHOLDER, [], tensor.name)
        self.currBasicBlock.addTensor(tensor.name, tensor.shape, tensor.dtype, tensor.usage, tensor.dformat)
        self.currBasicBlock.addInput(tensor.name)

    def addOutputTensor(self, tensor):
        '''Register an existing tensor's name as an output of the current block.'''
        self.currBasicBlock.addOutput(tensor.name)

    def addOutput(self, shape, dtype, usage, dformat):
        '''Add a 'result-N' tensor, register it as a block output and return it.

        Raises:
            Exception: if there is no current basic block.
        '''
        if not self.currBasicBlock:
            raise Exception('addTensor called without valid basic block')

        name = 'result-{}'.format(self.currResultIdx)
        self.currResultIdx = self.currResultIdx + 1

        tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, None)
        self.currBasicBlock.addOutput(name)
        return tens

    def addOperator(self, op, inputs, outputs, attributes = None, quant_info = None):
        '''Add an operator to the current basic block.

        Raises:
            Exception: if op is PLACEHOLDER or CONST; those are added through
                addPlaceholder() and addConst().
        '''

        if op == tosa.Op.Op().PLACEHOLDER or \
           op == tosa.Op.Op().CONST:
            raise Exception('Use addPlaceholder() or addConst() to add PLACEHOLDER and CONST ops')

        return self.currBasicBlock.addOperator(op, inputs, outputs, attributes, quant_info)

    def setExpectedFailure(self, desc='', val=True):
        '''Mark (or unmark) this test as expected to fail, with a description.'''
        self.expectedFailure = val
        self.expectedFailureDesc = desc

    def serialize(self):
        '''Serialize the version and all basic blocks; return the builder's output.'''

        builder = self.builder

        Version.VersionStart(builder)
        Version.VersionAdd_major(builder, TOSA_VERSION[0])
        Version.VersionAdd_minor(builder, TOSA_VERSION[1])
        Version.VersionAdd_patch(builder, TOSA_VERSION[2])
        Version.VersionAdd_experimental(builder, TOSA_VERSION[3])
        version = Version.VersionEnd(builder)

        fbv_bb = TosaSerializer.serializeObjVec(builder, self.basicBlocks, TosaGraph.TosaGraphStartBlocksVector)

        TosaGraph.TosaGraphStart(builder)
        TosaGraph.TosaGraphAddVersion(builder, version)
        TosaGraph.TosaGraphAddBlocks(builder, fbv_bb)
        graph = TosaGraph.TosaGraphEnd(builder)

        self.builder.Finish(graph)
        return self.builder.Output()

    def writeJson(self, tosa_filename):
        '''Write a json test file so that it is fairly easy to pick up the test
        and generate commands for third party tool

        Despite the name, nothing is written to disk here: the JSON text is
        returned as a string. Only blocks named 'main' contribute inputs and
        outputs.
        '''
        test_desc = dict()

        test_desc['tosa_file'] = tosa_filename
        ifm_name = []
        ifm_shape = []
        ifm_file = []
        ofm_name = []
        ofm_file = []
        ofm_shape = []

        for b in self.basicBlocks:
            if b.name == 'main':
                for i in b.inputs:
                    ifm_name.append(i)
                    ifm_shape.append(b.tensors[i].shape)
                    ifm_file.append(b.tensors[i].placeholderFilename)
                for o in b.outputs:
                    ofm_name.append(o)
                    ofm_shape.append(b.tensors[o].shape)
                    # Make up an OFM filename here. One isn't generated until the reference tool is
                    # run, so any name is a good name
                    ofm_file.append('ref-{}.npy'.format(o))

        test_desc['ifm_placeholder'] = ifm_name
        test_desc['ifm_file'] = ifm_file
        test_desc['ifm_shape'] = ifm_shape
        test_desc['ofm_name'] = ofm_name
        test_desc['ofm_shape'] = ofm_shape
        test_desc['ofm_file'] = ofm_file
        test_desc['expected_failure'] = self.expectedFailure
        # The description key is only present when a description was set.
        if self.expectedFailureDesc:
            test_desc['expected_failure_desc'] = self.expectedFailureDesc

        return json.dumps(test_desc, indent=' ')

    def startBasicBlock(self, name):
        '''Create a new basic block, make it current and append it to the graph.'''
        self.currBasicBlock = TosaSerializerBasicBlock(name)
        self.basicBlocks.append(self.currBasicBlock)

    @staticmethod
    def serializeStrVec(builder, vec, start_fcn):
        '''Build a vector of strings and return the EndVector result.'''
        # All strings are created before the vector is started.
        fb_strs = [builder.CreateString(i) for i in vec]
        start_fcn(builder, len(fb_strs))
        # Elements are prepended, so walk the list back to front.
        for s in fb_strs[::-1]:
            builder.PrependUOffsetTRelative(s)
        return builder.EndVector(len(fb_strs))

    @staticmethod
    def serializeInt32Vec(builder, vec):
        '''Build a vector of int32 values and return the EndVector result.'''
        builder.StartVector(4, len(vec), 4)
        # Elements are prepended, so walk the list back to front.
        for v in vec[::-1]:
            builder.PrependInt32(v)
        return builder.EndVector(len(vec))

    @staticmethod
    def serializeObjVec(builder, vec, start_fcn):
        '''Serialize each object in vec, build a vector of the results and return the EndVector result.'''
        # Objects are serialized last-to-first, so prepending the results in
        # that same order leaves the vector in the order of vec.
        serialized_vec = []
        for v in vec[::-1]:
            serialized_vec.append(v.serialize(builder))

        start_fcn(builder, len(vec))
        for v in serialized_vec:
            builder.PrependUOffsetTRelative(v)
        return builder.EndVector(len(vec))

    @staticmethod
    def toList(val):
        '''Return val itself if it is a list, otherwise a one-element list holding val.'''
        if isinstance(val, list):
            return val
        else:
            return [val]

    @staticmethod
    def setTosaVersion():
        '''Populate the module-level TOSA_VERSION list.'''
        # Create a dummy flatbuffers file with the default version information
        # There does not appear to be a better way to get a constant from a
        # flatbuffer schema file
        builder = flatbuffers.Builder(0)
        Version.VersionStart(builder)
        ver = Version.VersionEnd(builder)
        TosaGraph.TosaGraphStart(builder)
        TosaGraph.TosaGraphAddVersion(builder, ver)
        gr = TosaGraph.TosaGraphEnd(builder)
        builder.Finish(gr)

        out = builder.Output()

        gr = TosaGraph.TosaGraph()
        root = gr.GetRootAsTosaGraph(out, 0)

        # Store the version as a global variable so that it only needs to be
        # generated once per process.
        global TOSA_VERSION
        TOSA_VERSION = [root.Version()._major(),
                        root.Version()._minor(),
                        root.Version()._patch(),
                        root.Version()._experimental() ]