blob: d85494d6ec849a5ee028fc4d72f4ec48e9afe974 [file] [log] [blame]
Kevin Chengfea5a372021-10-11 18:38:47 +00001# Copyright (c) 2020-2021, ARM Limited.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15#!/usr/bin/env python3
16
17import os
18import sys
19import json
20import flatbuffers
21import numpy as np
22import struct
23from enum import Enum, IntEnum, unique
24from tosa import (
25 TosaGraph,
26 TosaBasicBlock,
27 TosaTensor,
28 TosaOperator,
29 DType,
30 Op,
31 ResizeMode,
32 Version,
33)
34from tosa_ref_run import TosaReturnCode
35
36import tosa
37
# These values must match the default version declared in schema/tosa.fbs.
TOSA_VERSION_MAJOR = 0
TOSA_VERSION_MINOR = 23
TOSA_VERSION_PATCH = 0
TOSA_VERSION_DRAFT = True
TOSA_VERSION = [
    TOSA_VERSION_MAJOR,
    TOSA_VERSION_MINOR,
    TOSA_VERSION_PATCH,
    TOSA_VERSION_DRAFT,
]

# flatc's generated Python enums only expose integer attributes, so there is no
# programmatic way to turn a DType value back into its name. The table below is
# maintained by hand and is indexed by the integer DType value.
# Note: this rebinds the ``DType`` name imported from ``tosa`` above to an
# instance of the generated enum class; the code below uses it as DType.INT8 etc.
DType = tosa.DType.DType()
DTypeNames = "UNKNOWN BOOL UINT8 INT4 INT8 INT16 INT32 INT48 FLOAT".split()

# Low-byte mask as a numpy.uint64.
ByteMask = np.uint64(0xFF)
64
65
def dtype_str_to_val(name):
    """Return the integer DType value whose name matches ``name``.

    The match against ``DTypeNames`` is case-insensitive (``str.casefold``),
    and the returned value is the matching entry's index in that table.

    Raises:
        ValueError: if ``name`` matches no entry in ``DTypeNames``.
    """
    wanted = name.casefold()
    for value, dtype_name in enumerate(DTypeNames):
        if dtype_name.casefold() == wanted:
            return value
    # ValueError subclasses Exception, so handlers written against the
    # previous bare Exception still catch it.
    raise ValueError("Unable to parse DType name {}".format(name))
72
73
class TosaSerializerUnion:
    """Collects the fields of one flatbuffers union member and serializes them.

    Subclasses set ``optFcns`` and ``utype`` and append ``(add_fn, value)``
    pairs to the typed field lists. ``serialize`` replays those pairs into a
    builder.
    """

    def __init__(self):
        # (start_fn, end_fn) of the generated table builder functions.
        # Set by the option constructors in subclasses.
        self.optFcns = None

        # Union type value for the table. Set by the option constructors in
        # subclasses.
        self.utype = None

        # Lists of (add_fn, value) tuples, one list per field kind. Filled in
        # by the option constructors in subclasses.
        self.ints = []
        self.bools = []
        self.floats = []
        self.strings = []
        self.intvecs = []
        self.fpvecs = []

    def serialize(self, builder):
        """Write the recorded fields into ``builder``; return what end_fn returns."""
        # Strings and vectors have to be created before the table is started,
        # so build them first and keep their offsets.
        offsets = [(add_fn, builder.CreateString(text)) for add_fn, text in self.strings]
        offsets += [
            (add_fn, TosaSerializer.serializeInt32Vec(builder, vec))
            for add_fn, vec in self.intvecs
        ]
        offsets += [
            (add_fn, TosaSerializer.serializeFpVec(builder, vec))
            for add_fn, vec in self.fpvecs
        ]

        start_fn, end_fn = self.optFcns

        # Build the table: scalars first (ints, bools, floats), then the
        # offsets of the strings and vectors created above.
        start_fn(builder)
        for add_fn, value in self.ints + self.bools + self.floats + offsets:
            add_fn(builder, value)
        return end_fn(builder)
134
135
class TosaSerializerAttribute(TosaSerializerUnion):
    """Builds one member of the Attribute union attached to an operator.

    Each method selects one Attribute table. It records the union type in
    ``utype`` and the table's start/end functions in ``optFcns``. It then
    appends one ``(add_fn, value)`` pair per field to the list that matches
    the field's kind. ``utype`` holds a single value, so an instance describes
    one table; the inherited ``serialize`` writes it out.
    """

    def __init__(self):
        super().__init__()

    def PoolAttribute(self, kernel, stride, padding):
        """Record a PoolAttribute table (int vectors padding, kernel, stride)."""
        from tosa import PoolAttribute as a, Attribute

        self.utype = Attribute.Attribute().PoolAttribute

        self.optFcns = (a.PoolAttributeStart, a.PoolAttributeEnd)
        self.intvecs.append((a.PoolAttributeAddPadding, padding))
        self.intvecs.append((a.PoolAttributeAddKernel, kernel))
        self.intvecs.append((a.PoolAttributeAddStride, stride))

    def ConvAttribute(self, padding, stride, dilation):
        """Record a ConvAttribute table (int vectors padding, stride, dilation)."""
        from tosa import ConvAttribute as a, Attribute

        self.utype = Attribute.Attribute().ConvAttribute
        self.optFcns = (a.ConvAttributeStart, a.ConvAttributeEnd)

        self.intvecs.append((a.ConvAttributeAddPadding, padding))
        self.intvecs.append((a.ConvAttributeAddStride, stride))
        self.intvecs.append((a.ConvAttributeAddDilation, dilation))

    def TransposeConvAttribute(self, outpad, stride, dilation, output_shape):
        """Record a TransposeConvAttribute table (four int vectors)."""
        from tosa import TransposeConvAttribute as a, Attribute

        self.utype = Attribute.Attribute().TransposeConvAttribute
        self.optFcns = (a.TransposeConvAttributeStart, a.TransposeConvAttributeEnd)

        self.intvecs.append((a.TransposeConvAttributeAddOutpad, outpad))
        self.intvecs.append((a.TransposeConvAttributeAddStride, stride))
        self.intvecs.append((a.TransposeConvAttributeAddDilation, dilation))
        self.intvecs.append((a.TransposeConvAttributeAddOutputShape, output_shape))

    def ReluNAttribute(self, maxint, maxfp):
        """Record a ReluNAttribute table (max_int integer, max_fp float)."""
        from tosa import ReluNAttribute as a, Attribute

        self.utype = Attribute.Attribute().ReluNAttribute
        self.optFcns = (a.ReluNAttributeStart, a.ReluNAttributeEnd)

        self.ints.append((a.ReluNAttributeAddMaxInt, maxint))
        # max_fp is a float field in schema/tosa.fbs, so file it under floats.
        # serialize() replays floats right after ints (no bools here), so the
        # order in which fields are written is unchanged.
        self.floats.append((a.ReluNAttributeAddMaxFp, maxfp))

    def AxisAttribute(self, axis):
        """Record an AxisAttribute table (single int axis)."""
        from tosa import AxisAttribute as a, Attribute

        self.utype = Attribute.Attribute().AxisAttribute
        self.optFcns = (a.AxisAttributeStart, a.AxisAttributeEnd)

        self.ints.append((a.AxisAttributeAddAxis, axis))

    def ReshapeAttribute(self, shape):
        """Record a ReshapeAttribute table (int vector shape)."""
        from tosa import ReshapeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ReshapeAttribute
        self.optFcns = (a.ReshapeAttributeStart, a.ReshapeAttributeEnd)

        self.intvecs.append((a.ReshapeAttributeAddShape, shape))

    def SliceAttribute(self, begin, size):
        """Record a SliceAttribute table (int vectors begin, size)."""
        from tosa import SliceAttribute as a, Attribute

        self.utype = Attribute.Attribute().SliceAttribute
        self.optFcns = (a.SliceAttributeStart, a.SliceAttributeEnd)

        self.intvecs.append((a.SliceAttributeAddBegin, begin))
        self.intvecs.append((a.SliceAttributeAddSize, size))

    def TileAttribute(self, multiples):
        """Record a TileAttribute table (int vector multiples)."""
        from tosa import TileAttribute as a, Attribute

        self.utype = Attribute.Attribute().TileAttribute
        self.optFcns = (a.TileAttributeStart, a.TileAttributeEnd)

        self.intvecs.append((a.TileAttributeAddMultiples, multiples))

    def ResizeAttribute(
        self, output_size, stride, offset, shift, stride_fp, offset_fp, mode
    ):
        """Record a ResizeAttribute table.

        Fields: int vectors output_size/stride/offset, int shift, float vectors
        stride_fp/offset_fp, and ``mode`` (written through the ints list).
        """
        from tosa import ResizeAttribute as a, Attribute

        self.utype = Attribute.Attribute().ResizeAttribute
        self.optFcns = (a.ResizeAttributeStart, a.ResizeAttributeEnd)

        self.intvecs.append((a.ResizeAttributeAddOutputSize, output_size))
        self.intvecs.append((a.ResizeAttributeAddStride, stride))
        self.intvecs.append((a.ResizeAttributeAddOffset, offset))
        self.ints.append((a.ResizeAttributeAddShift, shift))
        self.fpvecs.append((a.ResizeAttributeAddStrideFp, stride_fp))
        self.fpvecs.append((a.ResizeAttributeAddOffsetFp, offset_fp))
        self.ints.append((a.ResizeAttributeAddMode, mode))

    def ClampAttribute(self, minint, maxint, minfp, maxfp):
        """Record a ClampAttribute table (int min/max and float min/max)."""
        from tosa import ClampAttribute as a, Attribute

        self.utype = Attribute.Attribute().ClampAttribute
        self.optFcns = (a.ClampAttributeStart, a.ClampAttributeEnd)

        self.ints.append((a.ClampAttributeAddMinInt, minint))
        self.ints.append((a.ClampAttributeAddMaxInt, maxint))

        # min_fp/max_fp are float fields in schema/tosa.fbs, so file them under
        # floats. serialize() replays floats right after ints (no bools here),
        # so the order in which fields are written is unchanged.
        self.floats.append((a.ClampAttributeAddMinFp, minfp))
        self.floats.append((a.ClampAttributeAddMaxFp, maxfp))

    def RescaleAttribute(
        self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel
    ):
        """Record a RescaleAttribute table (ints, int vectors and bools)."""
        from tosa import RescaleAttribute as a, Attribute

        self.utype = Attribute.Attribute().RescaleAttribute
        self.optFcns = (a.RescaleAttributeStart, a.RescaleAttributeEnd)

        self.ints.append((a.RescaleAttributeAddInputZp, input_zp))
        self.ints.append((a.RescaleAttributeAddOutputZp, output_zp))
        self.intvecs.append((a.RescaleAttributeAddMultiplier, multiplier))
        self.intvecs.append((a.RescaleAttributeAddShift, shift))
        self.bools.append((a.RescaleAttributeAddScale32, scale32))
        self.bools.append((a.RescaleAttributeAddDoubleRound, double_round))
        self.bools.append((a.RescaleAttributeAddPerChannel, per_channel))

    def MulAttribute(self, shift):
        """Record a MulAttribute table (single int shift)."""
        from tosa import MulAttribute as a, Attribute

        self.utype = Attribute.Attribute().MulAttribute
        self.optFcns = (a.MulAttributeStart, a.MulAttributeEnd)

        self.ints.append((a.MulAttributeAddShift, shift))

    def ArithmeticRightShiftAttribute(self, round):
        """Record an ArithmeticRightShiftAttribute table (single bool round)."""
        # ``round`` shadows the builtin but is part of the public signature
        # (callers may pass it by keyword), so it is kept.
        from tosa import ArithmeticRightShiftAttribute as a, Attribute

        self.utype = Attribute.Attribute().ArithmeticRightShiftAttribute
        self.optFcns = (
            a.ArithmeticRightShiftAttributeStart,
            a.ArithmeticRightShiftAttributeEnd,
        )

        self.bools.append((a.ArithmeticRightShiftAttributeAddRound, round))

    def CustomAttribute(self, identifier):
        """Record a CustomAttribute table (string identifier)."""
        from tosa import CustomAttribute as a, Attribute

        self.utype = Attribute.Attribute().CustomAttribute
        self.optFcns = (a.CustomAttributeStart, a.CustomAttributeEnd)

        self.strings.append((a.CustomAttributeAddIdentifier, identifier))

    def CondIfAttribute(self, then_branch, else_branch):
        """Record a CondIfAttribute table (string names of the two branches)."""
        from tosa import CondIfAttribute as a, Attribute

        self.utype = Attribute.Attribute().CondIfAttribute
        self.optFcns = (a.CondIfAttributeStart, a.CondIfAttributeEnd)

        self.strings.append((a.CondIfAttributeAddThenBranch, then_branch))
        self.strings.append((a.CondIfAttributeAddElseBranch, else_branch))

    def WhileLoopAttribute(self, cond_branch, body_branch):
        """Record a WhileLoopAttribute table (string names of cond/body blocks)."""
        from tosa import WhileLoopAttribute as a, Attribute

        self.utype = Attribute.Attribute().WhileLoopAttribute
        self.optFcns = (a.WhileLoopAttributeStart, a.WhileLoopAttributeEnd)

        self.strings.append((a.WhileLoopAttributeAddCondBranch, cond_branch))
        self.strings.append((a.WhileLoopAttributeAddBodyBranch, body_branch))
303
304
class TosaSerializerQuantInfo(TosaSerializerUnion):
    """Builds one member of the QuantInfo union attached to an operator.

    Each method selects a QuantInfo table and records its integer ``*_zp``
    fields; the inherited ``serialize`` writes the table out.
    """

    def __init__(self):
        super().__init__()

    def ConvQuantInfo(self, input_zp, weight_zp):
        from tosa import ConvQuantInfo as table, QuantInfo

        self.utype = QuantInfo.QuantInfo().ConvQuantInfo
        self.optFcns = (table.ConvQuantInfoStart, table.ConvQuantInfoEnd)
        self.ints.extend(
            [
                (table.ConvQuantInfoAddInputZp, input_zp),
                (table.ConvQuantInfoAddWeightZp, weight_zp),
            ]
        )

    def UnaryQuantInfo(self, input_zp, output_zp):
        from tosa import UnaryQuantInfo as table, QuantInfo

        self.utype = QuantInfo.QuantInfo().UnaryQuantInfo
        self.optFcns = (table.UnaryQuantInfoStart, table.UnaryQuantInfoEnd)
        self.ints.extend(
            [
                (table.UnaryQuantInfoAddInputZp, input_zp),
                (table.UnaryQuantInfoAddOutputZp, output_zp),
            ]
        )

    def MatMulQuantInfo(self, a_zp, b_zp):
        from tosa import MatMulQuantInfo as table, QuantInfo

        self.utype = QuantInfo.QuantInfo().MatMulQuantInfo
        self.optFcns = (table.MatMulQuantInfoStart, table.MatMulQuantInfoEnd)
        self.ints.extend(
            [
                (table.MatMulQuantInfoAddAZp, a_zp),
                (table.MatMulQuantInfoAddBZp, b_zp),
            ]
        )

    def PadQuantInfo(self, input_zp):
        from tosa import PadQuantInfo as table, QuantInfo

        self.utype = QuantInfo.QuantInfo().PadQuantInfo
        self.optFcns = (table.PadQuantInfoStart, table.PadQuantInfoEnd)
        self.ints.extend([(table.PadQuantInfoAddInputZp, input_zp)])
341
342
class TosaSerializerTensor:
    """A named tensor of a basic block: shape, dtype and optional constant data."""

    def __init__(
        self,
        name,
        shape,
        dtype,
        data=None,
        placeholderFilename=None,
    ):
        self.name = name

        if isinstance(shape, np.ndarray):
            shape = shape.astype(int).tolist()
        self.shape = list(map(int, shape))
        self.dtype = dtype

        # Constant payload, flattened to a list of Python scalars. FLOAT data
        # is kept as float: converting it with int() would drop the fractional
        # part of every value before serialize() packs it as float32.
        scalar = float if dtype == DType.FLOAT else int
        if isinstance(data, np.ndarray):
            self.data = data.flatten().astype(scalar).tolist()
        elif isinstance(data, list):
            self.data = list(map(scalar, data))
        else:
            self.data = None

        # Filename for placeholder tensors. These get generated by the test generation
        # process and are written to disk, but are considered input tensors by the network
        # so they do not appear in the TOSA serialization. However, if we want to form a unit
        # test around these input tensors, we can get the filename from here.
        self.placeholderFilename = placeholderFilename

    def __str__(self):
        return "TosaSerializerTensor name: {} shape: {} dtype: {}".format(
            self.name,
            self.shape,
            DTypeNames[self.dtype],
        )

    def setDtype(self, dtype):
        self.dtype = dtype

    def _packData(self):
        """Return ``self.data`` encoded as a list of little-endian byte values.

        - BOOL/INT8/INT16/INT32/INT48 are written as the low 1/1/2/4/6 bytes
          of each value's two's-complement bit pattern.
        - INT4 packs two values per byte, the first in the low nibble. An odd
          trailing value gets a zero high nibble.
        - FLOAT is written as IEEE-754 single precision.

        Raises:
            Exception: if ``self.dtype`` has no byte encoding here.
        """
        # Storage width in bytes of each byte-aligned integer dtype.
        int_widths = {
            DType.BOOL: 1,
            DType.INT8: 1,
            DType.INT16: 2,
            DType.INT32: 4,
            DType.INT48: 6,
        }
        packed = bytearray()
        if self.dtype in int_widths:
            width = int_widths[self.dtype]
            mask = (1 << (8 * width)) - 1
            for val in self.data:
                # Masking a Python int yields its two's-complement low bits, so
                # negative values wrap to the same bytes an np.uintN cast gave on
                # NumPy 1.x. np.uintN(negative int) raises OverflowError on
                # NumPy >= 2, so it is avoided.
                packed += (int(val) & mask).to_bytes(width, "little")
        elif self.dtype == DType.INT4:
            count = len(self.data)
            for i in range(0, count, 2):
                low = int(self.data[i]) & 0xF
                high = (int(self.data[i + 1]) & 0xF) if i + 1 < count else 0
                packed.append(low | (high << 4))
        elif self.dtype == DType.FLOAT:
            for val in self.data:
                packed += struct.pack("<f", val)
        else:
            raise Exception(
                "unsupported data type {}".format(DTypeNames[self.dtype])
            )
        return list(packed)

    def serialize(self, builder):
        """Write this tensor into ``builder`` and return its table offset."""
        fb_name = builder.CreateString(self.name)
        fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape)
        # The data vector must be created before the tensor table is started.
        fb_data = None
        if self.data:
            fb_data = TosaSerializer.serializeUint8Vec(builder, self._packData())

        TosaTensor.TosaTensorStart(builder)
        TosaTensor.TosaTensorAddName(builder, fb_name)
        TosaTensor.TosaTensorAddShape(builder, fb_shapes)
        TosaTensor.TosaTensorAddType(builder, self.dtype)
        if fb_data is not None:
            TosaTensor.TosaTensorAddData(builder, fb_data)

        return TosaTensor.TosaTensorEnd(builder)
456
457
class TosaSerializerOperator:
    """One operator node: opcode, input/output tensor names, optional unions."""

    def __init__(self, op, inputs, outputs, attributes=None, quantInfo=None):
        self.op = op
        self.attributes = attributes
        self.inputs = TosaSerializer.toList(inputs)
        self.outputs = TosaSerializer.toList(outputs)
        self.quantInfo = quantInfo

    def __str__(self):
        pieces = ["Op {}\n----\n".format(self.op)]
        pieces.extend(" Input: {}\n".format(tensor_name) for tensor_name in self.inputs)
        pieces.extend(" Output: {}\n".format(tensor_name) for tensor_name in self.outputs)
        return "".join(pieces)

    def serialize(self, builder):
        """Write this operator into ``builder`` and return its table offset."""
        start_inputs = TosaOperator.TosaOperatorStartInputsVector
        start_outputs = TosaOperator.TosaOperatorStartOutputsVector
        fb_inputs = TosaSerializer.serializeStrVec(builder, self.inputs, start_inputs)
        fb_outputs = TosaSerializer.serializeStrVec(builder, self.outputs, start_outputs)

        has_attributes = self.attributes is not None
        has_qinfo = self.quantInfo is not None
        # The union tables are separate objects, so build them before the
        # operator table is started.
        fb_attributes = self.attributes.serialize(builder) if has_attributes else None
        fb_qinfo = self.quantInfo.serialize(builder) if has_qinfo else None

        TosaOperator.TosaOperatorStart(builder)
        TosaOperator.TosaOperatorAddOp(builder, self.op)
        TosaOperator.TosaOperatorAddInputs(builder, fb_inputs)
        TosaOperator.TosaOperatorAddOutputs(builder, fb_outputs)
        if has_attributes:
            TosaOperator.TosaOperatorAddAttributeType(builder, self.attributes.utype)
            TosaOperator.TosaOperatorAddAttribute(builder, fb_attributes)
        if has_qinfo:
            TosaOperator.TosaOperatorAddQuantInfoType(builder, self.quantInfo.utype)
            TosaOperator.TosaOperatorAddQuantInfo(builder, fb_qinfo)

        return TosaOperator.TosaOperatorEnd(builder)
502
503
class TosaSerializerBasicBlock:
    """A named basic block: its tensors, operators and input/output names."""

    def __init__(self, name):
        self.name = name
        self.operators = []

        # Dict assures uniqueness, but allows us to look up by name
        self.tensors = dict()

        self.inputs = []
        self.outputs = []

    def addTensor(
        self,
        name,
        shape,
        dtype,
        data=None,
        placeholderFilename=None,
    ):
        """Return the tensor called ``name``, creating it if it does not exist.

        If a tensor with this name was already added, it is returned unchanged
        and the other arguments are ignored.
        """
        if name not in self.tensors:
            self.tensors[name] = TosaSerializerTensor(
                name, shape, dtype, data, placeholderFilename
            )
        return self.tensors[name]

    def addInput(self, name):
        self.inputs.append(name)

    def addOutput(self, name):
        self.outputs.append(name)

    def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None):
        self.operators.append(
            TosaSerializerOperator(op, inputs, outputs, attributes, quant_info)
        )

    def serialize(self, builder):
        """Write this block into ``builder`` and return its table offset."""
        fb_name = builder.CreateString(self.name)
        fbv_inputs = TosaSerializer.serializeStrVec(
            builder, list(self.inputs), TosaBasicBlock.TosaBasicBlockStartInputsVector
        )
        fbv_outputs = TosaSerializer.serializeStrVec(
            builder, list(self.outputs), TosaBasicBlock.TosaBasicBlockStartOutputsVector
        )
        fbv_tensors = TosaSerializer.serializeObjVec(
            builder,
            list(self.tensors.values()),
            TosaBasicBlock.TosaBasicBlockStartTensorsVector,
        )
        fbv_operators = TosaSerializer.serializeObjVec(
            builder, self.operators, TosaBasicBlock.TosaBasicBlockStartOperatorsVector
        )

        TosaBasicBlock.TosaBasicBlockStart(builder)
        TosaBasicBlock.TosaBasicBlockAddName(builder, fb_name)
        TosaBasicBlock.TosaBasicBlockAddInputs(builder, fbv_inputs)
        TosaBasicBlock.TosaBasicBlockAddOutputs(builder, fbv_outputs)
        TosaBasicBlock.TosaBasicBlockAddTensors(builder, fbv_tensors)
        TosaBasicBlock.TosaBasicBlockAddOperators(builder, fbv_operators)
        return TosaBasicBlock.TosaBasicBlockEnd(builder)
568
569
@unique
class TensorDir(IntEnum):
    """Category of a tensor; @unique guarantees the integer values are distinct."""

    PLACEHOLDER = 0  # fed in from outside the graph
    CONST = 1  # constant tensor
    INTERMEDIATE = 2  # produced and consumed inside the graph
    RESULT = 3  # graph output
576
577
class TosaSerializer:
    """Builds a TOSA flatbuffer graph and its JSON test descriptor.

    A basic block named "main" is created on construction and becomes the
    current block. The ``add*`` helpers create tensors and operators in the
    current block, and ``serialize`` emits the whole graph.
    """

    def __init__(self, pathPrefix):
        self.builder = flatbuffers.Builder(0)

        self.basicBlocks = []
        self.startBasicBlock("main")
        # Directory that placeholder .npy files are written into.
        self.pathPrefix = pathPrefix

        # Counters used to generate tensor names: input-N, const-N, layer-N,
        # result-N.
        self.currInputIdx = 0
        self.currConstIdx = 0
        self.currLayerIdx = 1
        self.currResultIdx = 0

        # Is this an illegal test that is expected to fail?
        self.expectedReturnCode = TosaReturnCode.VALID
        self.expectedFailure = False
        self.expectedFailureDesc = ""

    def __str__(self):
        return "".join(str(bb) for bb in self.basicBlocks)

    def addPlaceholder(self, shape, dtype, vals):
        """Add input tensor ``input-N`` to the current block and return it.

        If ``vals`` is not None, it is saved with np.save (pickling disabled)
        to ``<pathPrefix>/input-N.npy``.
        """
        if not self.currBasicBlock:
            raise Exception("addTensor called without valid basic block")

        name = "input-{}".format(self.currInputIdx)
        filename = "{}.npy".format(name)
        self.currInputIdx += 1

        tens = self.currBasicBlock.addTensor(name, shape, dtype, None, filename)
        # This is always an input to the block
        self.currBasicBlock.addInput(name)

        if vals is not None:
            np.save(os.path.join(self.pathPrefix, filename), vals, False)

        return tens

    def addConst(self, shape, dtype, vals):
        """Add constant tensor ``const-N`` holding ``vals`` plus its CONST op."""
        if not self.currBasicBlock:
            raise Exception("addTensor called without valid basic block")

        # Constants use their own counter. Sharing currInputIdx would leave
        # gaps in the input-N numbering of placeholders.
        name = "const-{}".format(self.currConstIdx)
        self.currConstIdx += 1

        tens = self.currBasicBlock.addTensor(name, shape, dtype, vals)
        # Add the operator now
        self.currBasicBlock.addOperator(tosa.Op.Op().CONST, [], name)

        return tens

    def addIntermediate(self, shape, dtype):
        """Add tensor ``layer-N`` (no data) to the current block and return it."""
        if not self.currBasicBlock:
            raise Exception("addTensor called without valid basic block")

        name = "layer-{}".format(self.currLayerIdx)
        self.currLayerIdx += 1

        return self.currBasicBlock.addTensor(name, shape, dtype, None)

    def addInputTensor(self, tensor):
        """Register an existing tensor object as an input of the current block."""
        self.currBasicBlock.addTensor(tensor.name, tensor.shape, tensor.dtype)
        self.currBasicBlock.addInput(tensor.name)

    def addOutputTensor(self, tensor):
        """Mark an existing tensor (by name) as an output of the current block."""
        self.currBasicBlock.addOutput(tensor.name)

    def addOutput(self, shape, dtype):
        """Add output tensor ``result-N`` to the current block and return it."""
        if not self.currBasicBlock:
            raise Exception("addTensor called without valid basic block")

        name = "result-{}".format(self.currResultIdx)
        self.currResultIdx += 1

        tens = self.currBasicBlock.addTensor(name, shape, dtype, None)
        self.currBasicBlock.addOutput(name)
        return tens

    def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None):
        """Add a non-CONST operator to the current block.

        Raises:
            Exception: if ``op`` is CONST; use ``addConst`` for constants.
        """
        if op == tosa.Op.Op().CONST:
            raise Exception("Use addConst() to add CONST ops")

        return self.currBasicBlock.addOperator(
            op, inputs, outputs, attributes, quant_info
        )

    def setExpectedReturnCode(self, val, desc=""):
        """Record the expected return code; anything but VALID is a failure."""
        self.expectedReturnCode = val
        self.expectedFailureDesc = desc

        # Unpredictable or error results are considered expected failures
        # for conformance
        self.expectedFailure = val != TosaReturnCode.VALID

    def serialize(self):
        """Finish the graph in ``self.builder`` and return the output buffer.

        NOTE(review): this writes into the shared ``self.builder`` and calls
        Finish(); presumably it is meant to be called once per instance.
        """
        builder = self.builder

        Version.VersionStart(builder)
        Version.VersionAdd_major(builder, TOSA_VERSION[0])
        Version.VersionAdd_minor(builder, TOSA_VERSION[1])
        Version.VersionAdd_patch(builder, TOSA_VERSION[2])
        Version.VersionAdd_draft(builder, TOSA_VERSION[3])
        version = Version.VersionEnd(builder)

        fbv_bb = TosaSerializer.serializeObjVec(
            builder, self.basicBlocks, TosaGraph.TosaGraphStartBlocksVector
        )

        TosaGraph.TosaGraphStart(builder)
        TosaGraph.TosaGraphAddVersion(builder, version)
        TosaGraph.TosaGraphAddBlocks(builder, fbv_bb)
        graph = TosaGraph.TosaGraphEnd(builder)

        self.builder.Finish(graph)
        return self.builder.Output()

    def writeJson(self, tosa_filename):
        """Write a json test file so that it is fairly easy to pick up the test
        and generate commands for third party tool"""
        test_desc = dict()

        test_desc["tosa_file"] = tosa_filename
        ifm_name = []
        ifm_file = []
        ofm_name = []
        ofm_file = []

        for b in self.basicBlocks:
            if b.name != "main":
                continue
            for i in b.inputs:
                ifm_name.append(i)
                ifm_file.append(b.tensors[i].placeholderFilename)
            for o in b.outputs:
                ofm_name.append(o)
                # Make up an OFM filename here. One isn't generated until the reference tool is
                # run, so any name is a good name
                ofm_file.append("ref-{}.npy".format(o))

        test_desc["ifm_name"] = ifm_name
        test_desc["ifm_file"] = ifm_file
        test_desc["ofm_name"] = ofm_name
        test_desc["ofm_file"] = ofm_file
        test_desc["expected_return_code"] = self.expectedReturnCode
        test_desc["expected_failure"] = self.expectedFailure
        if self.expectedFailureDesc:
            test_desc["expected_failure_desc"] = self.expectedFailureDesc

        return json.dumps(test_desc, indent=" ")

    def startBasicBlock(self, name):
        """Create a new basic block, append it and make it the current block."""
        self.currBasicBlock = TosaSerializerBasicBlock(name)
        self.basicBlocks.append(self.currBasicBlock)

    @staticmethod
    def _endVector(builder, length):
        """Finish the vector being built under both the Flatbuffers 2.x and 1.x APIs.

        The 2.x ``EndVector()`` takes no argument and is tried first. A
        TypeError means the 1.x API, which needs the element count. If/when 1.x
        support is dropped, call ``builder.EndVector()`` directly.
        """
        try:
            return builder.EndVector()
        except TypeError:
            return builder.EndVector(length)

    @staticmethod
    def serializeStrVec(builder, vec, start_fcn):
        """Serialize ``vec`` as a vector of strings; return the vector offset."""
        # Strings must be created before the vector is started.
        fb_strs = [builder.CreateString(i) for i in vec]
        start_fcn(builder, len(fb_strs))
        for s in reversed(fb_strs):
            builder.PrependUOffsetTRelative(s)
        return TosaSerializer._endVector(builder, len(fb_strs))

    @staticmethod
    def serializeUint8Vec(builder, vec):
        """Serialize ``vec`` as a uint8 vector (8-byte aligned); return its offset."""
        builder.StartVector(1, len(vec), 8)
        for v in vec[::-1]:
            builder.PrependUint8(v)
        return TosaSerializer._endVector(builder, len(vec))

    @staticmethod
    def serializeInt32Vec(builder, vec):
        """Serialize ``vec`` as an int32 vector; return its offset."""
        builder.StartVector(4, len(vec), 4)
        for v in vec[::-1]:
            builder.PrependInt32(v)
        return TosaSerializer._endVector(builder, len(vec))

    @staticmethod
    def serializeFpVec(builder, vec):
        """Serialize ``vec`` as a float32 vector; return its offset."""
        builder.StartVector(4, len(vec), 4)
        for v in vec[::-1]:
            builder.PrependFloat32(v)
        return TosaSerializer._endVector(builder, len(vec))

    @staticmethod
    def serializeObjVec(builder, vec, start_fcn):
        """Serialize each object in ``vec`` and return a vector of their offsets."""
        # Objects are built last-to-first and their offsets prepended in that
        # order, so the final vector lists them first-to-last.
        serialized_vec = [v.serialize(builder) for v in vec[::-1]]

        start_fcn(builder, len(vec))
        for v in serialized_vec:
            builder.PrependUOffsetTRelative(v)
        return TosaSerializer._endVector(builder, len(vec))

    @staticmethod
    def toList(val):
        """Return ``val`` if it is a list, otherwise ``[val]``."""
        if isinstance(val, list):
            return val
        return [val]
812