blob: 3b7e339e67798bd216b51908b07f89f357831f8c [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
2
# Copyright (c) 2020, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
17#!/usr/bin/env python3
18
19import flatbuffers
20import numpy as np
21from enum import Enum, IntEnum, unique
22from tosa import TosaGraph, TosaBasicBlock, TosaTensor, TosaOperator, DType, Format, Usage, Op, ResizeMode, Version
23import tosa
24import os
25import json
26
# With the way flatc generates its python types, there is no programmatic way
# to get string names for the integer types. Manually maintain a string table
# here. The list index of each name is used as its integer DType value, so the
# order presumably has to match the generated DType enumeration (verify against
# tosa/DType.py when the schema changes).
DTypeNames = [ 'UNKNOWN',
               'BOOL',
               'AINT8',
               'UINT8',
               'INT4',
               'INT8',
               'INT16',
               'INT32',
               'INT48',
               'FLOAT' ]


def dtype_str_to_val(name):
    '''Convert a DType name (case-insensitive) to its integer DType value.

    The returned value is the index of the matching entry in DTypeNames.

    Raises:
        Exception: if name does not match any entry in DTypeNames.
    '''
    # casefold() the query once instead of on every comparison
    wanted = name.casefold()
    for val, dtype_name in enumerate(DTypeNames):
        if dtype_name.casefold() == wanted:
            return val
    raise Exception('Unable to parse DType name {}'.format(name))
47
48
class TosaSerializerUnion:
    '''Base class for flatbuffer union members (attributes, quant info).

    Subclass constructor methods pick the union member by setting utype and
    optFcns and queue (add_function, value) pairs; serialize() then emits the
    member table.
    '''
    def __init__(self):

        # (start_function, end_function) for the member table; filled in by
        # the subclass constructor methods.
        self.optFcns = None

        # Union type tag for the chosen member (e.g. an Attribute or QuantInfo
        # enum value); also filled in by the subclass constructor methods.
        self.utype = None

        # Pending fields, each a list of (add_function, value) pairs.
        self.ints = []
        self.bools = []
        self.floats = []
        self.strings = []
        self.intvecs = []
        self.fpvecs = []

    def serialize(self, builder):
        '''Write the queued fields as one table and return its offset.'''

        # Strings and vectors have to be built before the table is started,
        # so turn them into offsets first.
        str_offsets = [(add, builder.CreateString(text)) for add, text in self.strings]
        intvec_offsets = [(add, TosaSerializer.serializeInt32Vec(builder, vec))
                          for add, vec in self.intvecs]
        fpvec_offsets = [(add, TosaSerializer.serializeFpVec(builder, vec))
                         for add, vec in self.fpvecs]

        start, end = self.optFcns

        # Then fill the table from the scalar values and prebuilt offsets.
        start(builder)
        pending_groups = (self.ints, self.bools, self.floats,
                          str_offsets, intvec_offsets, fpvec_offsets)
        for group in pending_groups:
            for add, value in group:
                add(builder, value)

        return end(builder)
108
class TosaSerializerAttribute(TosaSerializerUnion):
    '''This class handles encapsulating all of the enumerated types for attributes

    Each method selects one Attribute union member: it sets utype and optFcns
    and queues (add_function, value) pairs. Integer sequences go to intvecs
    (serialized as int32 vectors), float sequences to fpvecs (float32 vectors),
    and scalars to ints/bools/floats/strings. The generated Add function
    decides how a value is encoded; the list only groups the fields.
    '''

    def __init__(self):
        super().__init__()

    def Pool2dAttribute(self, kernel, stride, padding):
        '''Pool2d attribute: kernel, stride and padding as int32 vectors.'''
        from tosa import Pool2dAttribute as a, Attribute

        self.utype = Attribute.Attribute().Pool2dAttribute

        self.optFcns = (a.Pool2dAttributeStart, a.Pool2dAttributeEnd)
        self.intvecs.append((a.Pool2dAttributeAddPadding,
                             padding))
        self.intvecs.append((a.Pool2dAttributeAddKernel,
                             kernel))
        self.intvecs.append((a.Pool2dAttributeAddStride,
                             stride))

    def Conv2dAttribute(self, padding, stride, dilation):
        '''Conv2d attribute: padding, stride and dilation as int32 vectors.'''
        from tosa import Conv2dAttribute as a, Attribute

        self.utype = Attribute.Attribute().Conv2dAttribute
        self.optFcns = (a.Conv2dAttributeStart, a.Conv2dAttributeEnd)

        self.intvecs.append((a.Conv2dAttributeAddPadding,
                             padding))
        self.intvecs.append((a.Conv2dAttributeAddStride,
                             stride))
        self.intvecs.append((a.Conv2dAttributeAddDilation,
                             dilation))

    def TransposeConv2DAttribute(self, outpad, stride, dilation, output_shape):
        '''TransposeConv2d attribute: outpad, stride, dilation and output_shape as int32 vectors.'''
        from tosa import TransposeConv2dAttribute as a, Attribute

        self.utype = Attribute.Attribute().TransposeConv2dAttribute
        self.optFcns = (a.TransposeConv2dAttributeStart, a.TransposeConv2dAttributeEnd)

        self.intvecs.append((a.TransposeConv2dAttributeAddOutpad,
                             outpad))
        self.intvecs.append((a.TransposeConv2dAttributeAddStride,
                             stride))
        self.intvecs.append((a.TransposeConv2dAttributeAddDilation,
                             dilation))
        self.intvecs.append((a.TransposeConv2dAttributeAddOutputShape,
                             output_shape))

    def ReluNAttribute(self, maxint, maxfp):
        '''ReluN attribute: integer bound maxint and floating-point bound maxfp.'''
        from tosa import ReluNAttribute as a, Attribute

        self.utype = Attribute.Attribute().ReluNAttribute
        self.optFcns = (a.ReluNAttributeStart, a.ReluNAttributeEnd)

        self.ints.append((a.ReluNAttributeAddMaxInt, maxint))
        # The floating-point bound belongs with the other float scalars.
        self.floats.append((a.ReluNAttributeAddMaxFp, maxfp))

    def AxisAttribute(self, axis):
        '''Axis attribute: a single integer axis.'''
        from tosa import AxisAttribute as a, Attribute

        self.utype = Attribute.Attribute().AxisAttribute
        self.optFcns = (a.AxisAttributeStart, a.AxisAttributeEnd)

        self.ints.append((a.AxisAttributeAddAxis,
                          axis))

    def ReshapeAttribute(self, shape):
        '''Reshape attribute: the new shape as an int32 vector.'''
        from tosa import ReshapeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ReshapeAttribute
        self.optFcns = (a.ReshapeAttributeStart, a.ReshapeAttributeEnd)

        self.intvecs.append((a.ReshapeAttributeAddShape,
                             shape))

    def SliceAttribute(self, begin, size):
        '''Slice attribute: begin and size as int32 vectors.'''
        from tosa import SliceAttribute as a, Attribute

        self.utype = Attribute.Attribute().SliceAttribute
        self.optFcns = (a.SliceAttributeStart, a.SliceAttributeEnd)

        self.intvecs.append((a.SliceAttributeAddBegin,
                             begin))
        self.intvecs.append((a.SliceAttributeAddSize,
                             size))

    def TileAttribute(self, multiples):
        '''Tile attribute: multiples as an int32 vector.'''
        from tosa import TileAttribute as a, Attribute

        self.utype = Attribute.Attribute().TileAttribute
        self.optFcns = (a.TileAttributeStart, a.TileAttributeEnd)

        self.intvecs.append((a.TileAttributeAddMultiples,
                             multiples))

    def ResizeAttribute(self, output_size, stride, offset, shift, stride_fp, offset_fp, mode):
        '''Resize attribute.

        output_size, stride and offset are int32 vectors; stride_fp and
        offset_fp are float32 vectors; shift and mode are scalars.
        '''
        from tosa import ResizeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ResizeAttribute
        self.optFcns = (a.ResizeAttributeStart, a.ResizeAttributeEnd)

        self.intvecs.append((a.ResizeAttributeAddOutputSize,
                             output_size))
        self.intvecs.append((a.ResizeAttributeAddStride,
                             stride))
        self.intvecs.append((a.ResizeAttributeAddOffset,
                             offset))
        self.ints.append((a.ResizeAttributeAddShift,
                          shift))
        self.fpvecs.append((a.ResizeAttributeAddStrideFp,
                            stride_fp))
        self.fpvecs.append((a.ResizeAttributeAddOffsetFp,
                            offset_fp))
        self.ints.append((a.ResizeAttributeAddMode,
                          mode))

    def ClampAttribute(self, minint, maxint, minfp, maxfp):
        '''Clamp attribute: integer bounds minint/maxint and floating-point bounds minfp/maxfp.'''
        from tosa import ClampAttribute as a, Attribute

        self.utype = Attribute.Attribute().ClampAttribute
        self.optFcns = (a.ClampAttributeStart, a.ClampAttributeEnd)

        self.ints.append((a.ClampAttributeAddMinInt,
                          minint))
        self.ints.append((a.ClampAttributeAddMaxInt,
                          maxint))

        # The floating-point bounds belong with the other float scalars.
        self.floats.append((a.ClampAttributeAddMinFp,
                            minfp))
        self.floats.append((a.ClampAttributeAddMaxFp,
                            maxfp))

    def RescaleAttribute(self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel):
        '''Rescale attribute.

        input_zp/output_zp are integer scalars, multiplier/shift are int32
        vectors, and scale32/double_round/per_channel are booleans.
        '''
        from tosa import RescaleAttribute as a, Attribute

        self.utype = Attribute.Attribute().RescaleAttribute
        self.optFcns = (a.RescaleAttributeStart, a.RescaleAttributeEnd)

        self.ints.append((a.RescaleAttributeAddInputZp,
                          input_zp))
        self.ints.append((a.RescaleAttributeAddOutputZp,
                          output_zp))
        self.intvecs.append((a.RescaleAttributeAddMultiplier,
                             multiplier))
        self.intvecs.append((a.RescaleAttributeAddShift,
                             shift))
        self.bools.append((a.RescaleAttributeAddScale32,
                           scale32))
        self.bools.append((a.RescaleAttributeAddDoubleRound,
                           double_round))
        self.bools.append((a.RescaleAttributeAddPerChannel,
                           per_channel))

    def MulAttribute(self, shift):
        '''Mul attribute: integer result shift.'''
        from tosa import MulAttribute as a, Attribute

        self.utype = Attribute.Attribute().MulAttribute
        self.optFcns = (a.MulAttributeStart, a.MulAttributeEnd)

        self.ints.append((a.MulAttributeAddShift,
                          shift))

    def ArithmeticRightShiftAttribute(self, round):
        '''ArithmeticRightShift attribute: boolean round flag.'''
        from tosa import ArithmeticRightShiftAttribute as a, Attribute

        self.utype = Attribute.Attribute().ArithmeticRightShiftAttribute
        self.optFcns = (a.ArithmeticRightShiftAttributeStart, a.ArithmeticRightShiftAttributeEnd)

        self.bools.append((a.ArithmeticRightShiftAttributeAddRound,
                           round))

    def CustomAttribute(self, identifier):
        '''Custom attribute: a string identifier.'''
        from tosa import CustomAttribute as a, Attribute

        self.utype = Attribute.Attribute().CustomAttribute
        self.optFcns = (a.CustomAttributeStart, a.CustomAttributeEnd)

        self.strings.append((a.CustomAttributeAddIdentifier,
                             identifier))

    def CondIfAttribute(self, then_branch, else_branch):
        '''CondIf attribute: then/else branch names as strings.'''
        from tosa import CondIfAttribute as a, Attribute

        self.utype = Attribute.Attribute().CondIfAttribute
        self.optFcns = (a.CondIfAttributeStart, a.CondIfAttributeEnd)

        self.strings.append((a.CondIfAttributeAddThenBranch,
                             then_branch))
        self.strings.append((a.CondIfAttributeAddElseBranch,
                             else_branch))

    def WhileLoopAttribute(self, cond_branch, body_branch):
        '''WhileLoop attribute: condition/body branch names as strings.'''
        from tosa import WhileLoopAttribute as a, Attribute

        self.utype = Attribute.Attribute().WhileLoopAttribute
        self.optFcns = (a.WhileLoopAttributeStart, a.WhileLoopAttributeEnd)

        self.strings.append((a.WhileLoopAttributeAddCondBranch,
                             cond_branch))
        self.strings.append((a.WhileLoopAttributeAddBodyBranch,
                             body_branch))
310
class TosaSerializerQuantInfo(TosaSerializerUnion):
    '''Builds one member of the QuantInfo union; each method selects the member and queues its zero points.'''
    def __init__(self):
        super().__init__()

    def ConvQuantInfo(self, input_zp, weight_zp):
        '''Convolution quant info: input and weight zero points.'''
        from tosa import ConvQuantInfo as qi, QuantInfo

        self.utype = QuantInfo.QuantInfo().ConvQuantInfo
        self.optFcns = (qi.ConvQuantInfoStart, qi.ConvQuantInfoEnd)
        self.ints.extend([(qi.ConvQuantInfoAddInputZp, input_zp),
                          (qi.ConvQuantInfoAddWeightZp, weight_zp)])

    def UnaryQuantInfo(self, input_zp, output_zp):
        '''Unary-op quant info: input and output zero points.'''
        from tosa import UnaryQuantInfo as qi, QuantInfo

        self.utype = QuantInfo.QuantInfo().UnaryQuantInfo
        self.optFcns = (qi.UnaryQuantInfoStart, qi.UnaryQuantInfoEnd)
        self.ints.extend([(qi.UnaryQuantInfoAddInputZp, input_zp),
                          (qi.UnaryQuantInfoAddOutputZp, output_zp)])

    def MatMulQuantInfo(self, a_zp, b_zp):
        '''MatMul quant info: zero points of the A and B operands.'''
        from tosa import MatMulQuantInfo as qi, QuantInfo

        self.utype = QuantInfo.QuantInfo().MatMulQuantInfo
        self.optFcns = (qi.MatMulQuantInfoStart, qi.MatMulQuantInfoEnd)
        self.ints.extend([(qi.MatMulQuantInfoAddAZp, a_zp),
                          (qi.MatMulQuantInfoAddBZp, b_zp)])

    def PadQuantInfo(self, input_zp):
        '''Pad quant info: input zero point.'''
        from tosa import PadQuantInfo as qi, QuantInfo

        self.utype = QuantInfo.QuantInfo().PadQuantInfo
        self.optFcns = (qi.PadQuantInfoStart, qi.PadQuantInfoEnd)
        self.ints.extend([(qi.PadQuantInfoAddInputZp, input_zp)])
346
class TosaSerializerTensor:
    '''A named tensor in a basic block plus the metadata needed to serialize it.'''

    def __init__(self, name, shape, dtype, usage, dformat, filename = None, placeholderFilename = None):
        self.name = name

        # Normalize the shape to a plain list of Python ints (numpy arrays
        # and other integer-like sequences are accepted).
        if isinstance(shape, np.ndarray):
            shape = shape.astype(int).tolist()
        shape = list(map(int, shape))

        self.shape = shape
        self.dtype = dtype
        # Always own these lists: a caller-supplied list (e.g. another tensor's
        # usage list) must not be mutated by addUsage()/addFormat()/merge().
        self.usage = self._asList(usage)
        self.dformat = self._asList(dformat)

        # Filename for const tensors. This gets written to the .tosa serialization
        self.filename = filename

        # Filename for placeholder tensors. These get generated by the test generation
        # process and are written to disk, but are considered input tensors by the network
        # so they do not appear in the TOSA serialization. However, if we want to form a unit
        # test around these input tensors, we can get the filename from here.
        self.placeholderFilename = placeholderFilename

    @staticmethod
    def _asList(val):
        '''Return val as a new list: a copy if val is a list, otherwise [val].'''
        return list(val) if isinstance(val, list) else [val]

    def __str__(self):
        return 'TosaSerializerTensor name: {} shape: {} dtype: {} Usage: {} format {} filename: {}'.format(
            self.name, self.shape, DTypeNames[self.dtype], self.usage, self.dformat, self.filename)

    def addUsage(self, usage):
        self.usage.append(usage)

    def addFormat(self, format):
        self.dformat.append(format)

    def setDtype(self, dtype):
        self.dtype = dtype

    def merge(self, name, shape, dtype, usage, dformat, filename = None):
        '''Merge additional usage(s) and format(s) into this tensor.

        usage and dformat may each be a single value or a list of values. Each
        value not already recorded is appended, preserving order. name, shape,
        dtype and filename are accepted to mirror addTensor() but are ignored.
        Returns None.
        '''
        for u in self._asList(usage):
            if u not in self.usage:
                self.usage.append(u)

        for f in self._asList(dformat):
            if f not in self.dformat:
                self.dformat.append(f)

    def serialize(self, builder):
        '''Write this tensor as a TosaTensor table and return its offset.'''
        # Strings and vectors must be created before the table is started.
        fb_name = builder.CreateString(self.name)
        if self.filename:
            fb_filename = builder.CreateString(self.filename)
        fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape)
        fb_usage = TosaSerializer.serializeInt32Vec(builder, self.usage)
        fb_dformat = TosaSerializer.serializeInt32Vec(builder, self.dformat)

        TosaTensor.TosaTensorStart(builder)
        TosaTensor.TosaTensorAddName(builder, fb_name)
        TosaTensor.TosaTensorAddShape(builder, fb_shapes)
        TosaTensor.TosaTensorAddType(builder, self.dtype)
        TosaTensor.TosaTensorAddUsage(builder, fb_usage)
        TosaTensor.TosaTensorAddFormat(builder, fb_dformat)
        # Only const tensors carry a filename; placeholderFilename is never serialized.
        if self.filename:
            TosaTensor.TosaTensorAddNpyFilename(builder, fb_filename)

        return TosaTensor.TosaTensorEnd(builder)
419
class TosaSerializerOperator:
    '''One TOSA operator: opcode, input/output tensor names and optional attribute/quant-info unions.'''
    def __init__(self, op, inputs, outputs, attributes = None, quantInfo = None):
        self.op = op
        self.attributes = attributes
        self.inputs = TosaSerializer.toList(inputs)
        self.outputs = TosaSerializer.toList(outputs)
        self.quantInfo = quantInfo

    def __str__(self):
        pieces = ['Op {}\n----\n'.format(self.op)]
        pieces.extend(' Input: {}\n'.format(tensor_name) for tensor_name in self.inputs)
        pieces.extend(' Output: {}\n'.format(tensor_name) for tensor_name in self.outputs)
        return ''.join(pieces)

    def serialize(self, builder):
        '''Write this operator as a TosaOperator table and return its offset.'''
        fb_inputs = TosaSerializer.serializeStrVec(builder, self.inputs, TosaOperator.TosaOperatorStartInputsVector)
        fb_outputs = TosaSerializer.serializeStrVec(builder, self.outputs, TosaOperator.TosaOperatorStartOutputsVector)

        has_attributes = self.attributes is not None
        has_qinfo = self.quantInfo is not None

        # The union payloads are tables of their own, so they have to be
        # finished before the operator table is started.
        fb_attributes = self.attributes.serialize(builder) if has_attributes else None
        fb_qinfo = self.quantInfo.serialize(builder) if has_qinfo else None

        TosaOperator.TosaOperatorStart(builder)
        TosaOperator.TosaOperatorAddOp(builder, self.op)
        TosaOperator.TosaOperatorAddInputs(builder, fb_inputs)
        TosaOperator.TosaOperatorAddOutputs(builder, fb_outputs)
        if has_attributes:
            TosaOperator.TosaOperatorAddAttributeType(builder, self.attributes.utype)
            TosaOperator.TosaOperatorAddAttribute(builder, fb_attributes)
        if has_qinfo:
            TosaOperator.TosaOperatorAddQuantInfoType(builder, self.quantInfo.utype)
            TosaOperator.TosaOperatorAddQuantInfo(builder, fb_qinfo)

        return TosaOperator.TosaOperatorEnd(builder)
460
class TosaSerializerBasicBlock:
    '''A named basic block: its tensors, operators and input/output tensor names.'''
    def __init__(self, name):
        self.name = name
        self.operators = []

        # Dict assures uniqueness, but allows us to look up by name
        self.tensors = dict()

        # Names of the block's input and output tensors, in the order added
        self.inputs = []
        self.outputs = []

    def addTensor(self, name, shape, dtype, usage, dformat, filename = None, placeholderFilename = None):
        '''Add a tensor, or merge usage/format into an existing tensor of the same name.

        Returns the TosaSerializerTensor stored under name.
        '''
        tens = self.tensors.get(name)
        if tens is None:
            tens = TosaSerializerTensor(name, shape, dtype, usage, dformat, filename, placeholderFilename)
            self.tensors[name] = tens
        else:
            # Someone already added this tensor; only its usages and formats
            # are extended (merge() returns nothing, so nothing is captured).
            tens.merge(name, shape, dtype, usage, dformat, filename)

        return tens

    def addInput(self, name):
        self.inputs.append(name)

    def addOutput(self, name):
        self.outputs.append(name)

    def addOperator(self, op, inputs, outputs, attributes = None, quant_info = None):
        '''Append a new operator to the block and return it.'''
        operator = TosaSerializerOperator(op, inputs, outputs, attributes, quant_info)
        self.operators.append(operator)
        return operator

    def serialize(self, builder):
        '''Write this block as a TosaBasicBlock table and return its offset.'''
        # All child strings/vectors/tables must exist before the table starts.
        fb_name = builder.CreateString(self.name)
        fbv_inputs = TosaSerializer.serializeStrVec(builder, list(self.inputs), TosaBasicBlock.TosaBasicBlockStartInputsVector)
        fbv_outputs = TosaSerializer.serializeStrVec(builder, list(self.outputs), TosaBasicBlock.TosaBasicBlockStartOutputsVector)
        fbv_tensors = TosaSerializer.serializeObjVec(builder, list(self.tensors.values()), TosaBasicBlock.TosaBasicBlockStartTensorsVector)
        fbv_operators = TosaSerializer.serializeObjVec(builder, self.operators, TosaBasicBlock.TosaBasicBlockStartOperatorsVector)

        TosaBasicBlock.TosaBasicBlockStart(builder)
        TosaBasicBlock.TosaBasicBlockAddName(builder, fb_name)
        TosaBasicBlock.TosaBasicBlockAddInputs(builder, fbv_inputs)
        TosaBasicBlock.TosaBasicBlockAddOutputs(builder, fbv_outputs)
        TosaBasicBlock.TosaBasicBlockAddTensors(builder, fbv_tensors)
        TosaBasicBlock.TosaBasicBlockAddOperators(builder, fbv_operators)
        return TosaBasicBlock.TosaBasicBlockEnd(builder)
506
@unique
class TensorDir(IntEnum):
    '''Integer tags for the role a tensor plays in a generated test.

    @unique guarantees that no two members share a value.
    '''
    # NOTE(review): presumably these mirror TosaSerializer.addPlaceholder /
    # addConst / addIntermediate / addOutput; no use of them is visible here.
    PLACEHOLDER  = 0
    CONST        = 1
    INTERMEDIATE = 2
    RESULT       = 3
513
class TosaSerializer:
    '''Builds a TOSA flatbuffer graph made of basic blocks.

    Placeholder and const tensor data is saved as .npy files under pathPrefix;
    serialize() returns the flatbuffer bytes and writeJson() returns a JSON
    test description.
    '''

    def __init__(self, pathPrefix):

        # Get the global TOSA version if not already defined
        try:
            TOSA_VERSION
        except NameError:
            TosaSerializer.setTosaVersion()

        self.builder = flatbuffers.Builder(0)

        self.basicBlocks = []
        self.startBasicBlock('main')
        # Directory that placeholder/const .npy files are saved into
        self.pathPrefix = pathPrefix

        # Indices used for adding/naming tensors
        self.currInputIdx = 0
        # NOTE(review): addConst() numbers const tensors from currInputIdx
        # (shared with placeholders), so this counter is never advanced.
        self.currConstIdx = 0
        self.currLayerIdx = 1
        self.currResultIdx = 0

        # Is this an illegal test that is expected to fail?
        self.expectedFailure = False
        self.expectedFailureDesc = ''

    def __str__(self):
        '''Concatenate the string forms of all basic blocks.'''
        return ''.join(str(bb) for bb in self.basicBlocks)

    def _checkBasicBlock(self):
        '''Raise if there is no current basic block to add tensors to.'''
        if not self.currBasicBlock:
            raise Exception('addTensor called without valid basic block')

    def addPlaceholder(self, shape, dtype, usage, dformat, vals):
        '''Add a new block input tensor named input-N with a PLACEHOLDER op.

        If vals is not None it is saved to <pathPrefix>/input-N.npy.
        Returns the TosaSerializerTensor.
        '''
        self._checkBasicBlock()

        name = 'input-{}'.format(self.currInputIdx)
        filename = '{}.npy'.format(name)
        self.currInputIdx = self.currInputIdx + 1

        # The .npy name goes in placeholderFilename only, so it is not
        # written into the TOSA serialization.
        tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, None, filename)
        # This is always an input to the block
        self.currBasicBlock.addInput(name)
        # Add the operator now
        self.currBasicBlock.addOperator(tosa.Op.Op().PLACEHOLDER, [], name)

        if vals is not None:
            np.save(os.path.join(self.pathPrefix, filename), vals, False)

        return tens

    def addConst(self, shape, dtype, usage, dformat, vals):
        '''Add a new const tensor named const-N with a CONST op.

        The .npy filename is recorded in the serialized tensor; if vals is not
        None it is saved to <pathPrefix>/const-N.npy. Returns the tensor.
        '''
        self._checkBasicBlock()

        # Shares the input counter with addPlaceholder (see __init__ note).
        name = 'const-{}'.format(self.currInputIdx)
        filename = '{}.npy'.format(name)
        self.currInputIdx = self.currInputIdx + 1

        tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, filename)
        # Add the operator now
        self.currBasicBlock.addOperator(tosa.Op.Op().CONST, [], name)

        if vals is not None:
            np.save(os.path.join(self.pathPrefix, filename), vals, False)
        return tens

    def addIntermediate(self, shape, dtype, usage, dformat):
        '''Add a tensor named layer-N with no backing file. Returns the tensor.'''
        self._checkBasicBlock()

        name = 'layer-{}'.format(self.currLayerIdx)
        filename = None # No file, so no filename
        self.currLayerIdx = self.currLayerIdx + 1

        return self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, filename)

    def addInputTensor(self, tensor):
        '''Register an existing tensor object as a block input with a PLACEHOLDER op.'''
        self.currBasicBlock.addOperator(tosa.Op.Op().PLACEHOLDER, [], tensor.name)
        self.currBasicBlock.addTensor(tensor.name, tensor.shape, tensor.dtype, tensor.usage, tensor.dformat)
        self.currBasicBlock.addInput(tensor.name)

    def addOutputTensor(self, tensor):
        '''Mark an existing tensor as a block output.'''
        self.currBasicBlock.addOutput(tensor.name)

    def addOutput(self, shape, dtype, usage, dformat):
        '''Add a new block output tensor named result-N. Returns the tensor.'''
        self._checkBasicBlock()

        name = 'result-{}'.format(self.currResultIdx)
        self.currResultIdx = self.currResultIdx + 1

        tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, None)
        self.currBasicBlock.addOutput(name)
        return tens

    def addOperator(self, op, inputs, outputs, attributes = None, quant_info = None):
        '''Add a non-PLACEHOLDER/CONST operator to the current basic block.

        Raises:
            Exception: if op is PLACEHOLDER or CONST; those are added together
                with their tensors by the dedicated methods.
        '''
        if op == tosa.Op.Op().PLACEHOLDER or \
           op == tosa.Op.Op().CONST:
            raise Exception('Use addPlaceholder()/addInputTensor() or addConst() to add PLACEHOLDER and CONST ops')

        return self.currBasicBlock.addOperator(op, inputs, outputs, attributes, quant_info)

    def setExpectedFailure(self, desc='', val=True):
        '''Mark (or unmark, with val=False) this test as expected to fail.'''
        self.expectedFailure = val
        self.expectedFailureDesc = desc

    def serialize(self):
        '''Finish the graph in self.builder and return the flatbuffer output.

        NOTE(review): this reuses self.builder, so calling serialize() more
        than once on the same instance is presumably unsupported — confirm.
        '''
        builder = self.builder

        Version.VersionStart(builder)
        Version.VersionAdd_major(builder, TOSA_VERSION[0])
        Version.VersionAdd_minor(builder, TOSA_VERSION[1])
        Version.VersionAdd_patch(builder, TOSA_VERSION[2])
        Version.VersionAdd_experimental(builder, TOSA_VERSION[3])
        version = Version.VersionEnd(builder)

        fbv_bb = TosaSerializer.serializeObjVec(builder, self.basicBlocks, TosaGraph.TosaGraphStartBlocksVector)

        TosaGraph.TosaGraphStart(builder)
        TosaGraph.TosaGraphAddVersion(builder, version)
        TosaGraph.TosaGraphAddBlocks(builder, fbv_bb)
        graph = TosaGraph.TosaGraphEnd(builder)

        self.builder.Finish(graph)
        return self.builder.Output()

    def writeJson(self, tosa_filename):
        '''Write a json test file so that it is fairly easy to pick up the test
        and generate commands for third party tool'''
        test_desc = dict()

        test_desc['tosa_file'] = tosa_filename
        ifm_name = []
        ifm_shape = []
        ifm_file = []
        ofm_name = []
        ofm_file = []
        ofm_shape = []

        for b in self.basicBlocks:
            if b.name == 'main':
                for i in b.inputs:
                    ifm_name.append(i)
                    ifm_shape.append(b.tensors[i].shape)
                    ifm_file.append(b.tensors[i].placeholderFilename)
                for o in b.outputs:
                    ofm_name.append(o)
                    ofm_shape.append(b.tensors[o].shape)
                    # Make up an OFM filename here. One isn't generated until the reference tool is
                    # run, so any name is a good name
                    ofm_file.append('ref-{}.npy'.format(o))

        test_desc['ifm_placeholder'] = ifm_name
        test_desc['ifm_file'] = ifm_file
        test_desc['ifm_shape'] = ifm_shape
        test_desc['ofm_name'] = ofm_name
        test_desc['ofm_shape'] = ofm_shape
        test_desc['ofm_file'] = ofm_file
        test_desc['expected_failure'] = self.expectedFailure
        if self.expectedFailureDesc:
            test_desc['expected_failure_desc'] = self.expectedFailureDesc

        return json.dumps(test_desc, indent=' ')

    def startBasicBlock(self, name):
        '''Create a new basic block, make it current and append it to the graph.'''
        self.currBasicBlock = TosaSerializerBasicBlock(name)
        self.basicBlocks.append(self.currBasicBlock)

    @staticmethod
    def serializeStrVec(builder, vec, start_fcn):
        '''Serialize a list of strings as a vector of string offsets.'''
        # Strings must be created before the vector is started.
        fb_strs = [builder.CreateString(i) for i in vec]
        start_fcn(builder, len(fb_strs))
        # Flatbuffers vectors are built back to front
        for s in fb_strs[::-1]:
            builder.PrependUOffsetTRelative(s)
        return builder.EndVector(len(fb_strs))

    @staticmethod
    def serializeInt32Vec(builder, vec):
        '''Serialize a sequence of ints as an int32 vector.'''
        builder.StartVector(4, len(vec), 4)
        for v in vec[::-1]:
            builder.PrependInt32(v)
        return builder.EndVector(len(vec))

    @staticmethod
    def serializeFpVec(builder, vec):
        '''Serialize a sequence of floats as a float32 vector.'''
        builder.StartVector(4, len(vec), 4)
        for v in vec[::-1]:
            builder.PrependFloat32(v)
        return builder.EndVector(len(vec))

    @staticmethod
    def serializeObjVec(builder, vec, start_fcn):
        '''Serialize objects exposing serialize(builder) as a vector of table offsets.'''
        # Serialize the children (in reverse) before starting the vector;
        # prepending those offsets in that order restores the original order.
        serialized_vec = []
        for v in vec[::-1]:
            serialized_vec.append(v.serialize(builder))

        start_fcn(builder, len(vec))
        for v in serialized_vec:
            builder.PrependUOffsetTRelative(v)
        return builder.EndVector(len(vec))

    @staticmethod
    def toList(val):
        '''Return val unchanged if it is a list, otherwise wrap it as [val].'''
        if isinstance(val, list):
            return val
        else:
            return [val]

    @staticmethod
    def setTosaVersion():
        '''Read the schema's default Version and store it in global TOSA_VERSION.'''
        # Create a dummy flatbuffers file with the default version information
        # There does not appear to be a better way to get a constant from a
        # flatbuffer schema file
        builder = flatbuffers.Builder(0)
        Version.VersionStart(builder)
        ver = Version.VersionEnd(builder)
        TosaGraph.TosaGraphStart(builder)
        TosaGraph.TosaGraphAddVersion(builder, ver)
        gr = TosaGraph.TosaGraphEnd(builder)
        builder.Finish(gr)

        out = builder.Output()

        gr = TosaGraph.TosaGraph()
        root = gr.GetRootAsTosaGraph(out, 0)

        # Store the version as a global variable so that it only needs to be
        # generated once per process.
        global TOSA_VERSION
        TOSA_VERSION = [root.Version()._major(),
                        root.Version()._minor(),
                        root.Version()._patch(),
                        root.Version()._experimental() ]