blob: ba31c2dc47f79a1d116b0d9aca3efc43930077b4 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18import numpy as np
19import argparse
20import sys
21import re
22import os
23import subprocess
24import shlex
25import json
26import glob
27import math
28import queue
29import threading
30import traceback
31import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010032import itertools
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
35
# The TOSA serialization bindings live in ../thirdparty/serialization_lib/python
# relative to this script; that directory must be on sys.path before the
# imports below can succeed.
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(
    os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
)

import tosa_serializer as ts
from tosa_serializer import *
import tosa

# The flatc-generated Python types only hold integer constants (they are not
# real enums), so create one instance of each and use it as a namespace:
# DType.INT8, ResizeMode.NEAREST, ...
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()
49
Kevin Cheng550ccc52021-03-03 11:21:43 -080050
class TosaQuantGen:
    """Random generators for quantization info.

    An operator definition picks one of the qg* helpers through its 'qgen'
    entry. Each helper builds a ts.TosaSerializerQuantInfo whose zero points
    are random for quantized dtypes (INT8 / INT16) and 0 otherwise.
    """

    def __init__(self):
        pass

    @staticmethod
    def needsQinfo(op, dtype):
        """Return True when ``dtype`` carries zero points (``op`` is unused)."""
        return dtype in (DType.INT8, DType.INT16)

    @staticmethod
    def _zeroPoints(testGen, op, dtype, count):
        """Return ``count`` zero points: random when quantized, else all 0.

        The random values are drawn with testGen.randInt() one at a time, in
        order, so callers can splat the list straight into a *QuantInfo call.
        """
        if not TosaQuantGen.needsQinfo(op, dtype):
            return [0] * count
        return [testGen.randInt() for _ in range(count)]

    @staticmethod
    def qgUnary(testGen, op, dtype):
        """Build UnaryQuantInfo with two zero points."""
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.UnaryQuantInfo(*TosaQuantGen._zeroPoints(testGen, op, dtype, 2))
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype):
        """Build ConvQuantInfo with two zero points."""
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.ConvQuantInfo(*TosaQuantGen._zeroPoints(testGen, op, dtype, 2))
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype):
        """Build MatMulQuantInfo with two zero points."""
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.MatMulQuantInfo(*TosaQuantGen._zeroPoints(testGen, op, dtype, 2))
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype):
        """Build PadQuantInfo with a single zero point."""
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.PadQuantInfo(*TosaQuantGen._zeroPoints(testGen, op, dtype, 1))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into an integer (multiplier, shift).

        Derived from computeMultiplierAndShiftTosaScale32. The multiplier uses
        31 fractional bits when ``scale32`` is true and 15 otherwise. The sign
        of ``scaleFp`` is discarded (the mantissa is made non-negative).

        Returns:
            (multiplier, shift) with multiplier <= 1 << bits and 0 <= shift <= 63.
        """
        scaleBits = 31 if scale32 else 15
        limit = 1 << scaleBits

        mantissa, exponent = math.frexp(scaleFp)
        if scaleFp < 0.0:
            mantissa = -mantissa

        multiplier = round(mantissa * limit)
        assert multiplier <= limit

        # Rounding can push the mantissa up to exactly 1.0; renormalise so the
        # multiplier stays within scaleBits bits.
        if multiplier == limit:
            multiplier //= 2
            exponent += 1

        shift = scaleBits - exponent

        assert multiplier <= limit
        assert 0 <= shift <= 63

        return multiplier, shift
129
130
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator. The actual random data is generated separately for each test.

    Each tg* function receives the test generator, the operator definition
    dict (read for its "operands" (placeholder_count, const_count) pair and,
    for convolution-like ops, its "filter" (height, width) pair) and the
    requested rank, and returns a list with one shape per operand.
    """

    def __init__(self):
        pass

    @staticmethod
    def _constrictBatch(testGen, shape):
        """Wrap shape[0] into [1, max_batch_size] in place when that arg is set.

        Returns ``shape`` so the call can be chained.
        """
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1
        return shape

    @staticmethod
    def tgBasic(testGen, opName, rank):
        """Return the same random shape for every operand."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)
        return [shape.copy() for _ in range(pl + const)]

    @staticmethod
    def tgNHWC(testGen, opName, rank):
        """Return the same rank-4 shape (batch optionally constricted) for every operand."""
        pl, const = opName["operands"]

        assert rank == 4

        shape = TosaTensorGen._constrictBatch(testGen, testGen.makeShape(rank))
        return [shape.copy() for _ in range(pl + const)]

    @staticmethod
    def tgScatter(testGen, opName, rank):
        """Return [values_in_shape, input_shape] for a rank-3 scatter.

        input_shape shares dims 0 and 2 with values_in_shape; its dim 1 (W)
        is drawn independently from tensor_shape_range.
        """
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = TosaTensorGen._constrictBatch(
            testGen, testGen.makeShape(rank)
        )

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        return [values_in_shape.copy(), input_shape.copy()]

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank):
        """Return one shape per operand, with one random operand broadcast.

        The chosen operand gets a single randomly selected dimension set to 1.
        """
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank):
        """Return [ifm (NHWC), filter (OHWI), bias (O)] shapes for CONV2D."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = TosaTensorGen._constrictBatch(testGen, testGen.makeShape(rank))

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank):
        """Return [ifm (NHWC), filter (OHWI), bias (O)] shapes for TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = TosaTensorGen._constrictBatch(testGen, testGen.makeShape(rank))

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank):
        """Return [ifm (NHWC), filter (HWCM), bias (C*M)] shapes for DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = TosaTensorGen._constrictBatch(testGen, testGen.makeShape(rank))

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C. Clamp the modulus to at least 1 so a
        # small tensor_shape_range (upper bound < 4) doesn't divide by zero.
        max_m = max(1, testGen.args.tensor_shape_range[1] // 4)
        filter_m = (testGen.makeShape(1)[0] % max_m) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank):
        """Return [input, filter (OC, IC), bias (OC)] shapes for FULLY_CONNECTED."""
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        filter_oc = testGen.makeShape(1)[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank):
        """Return [a, b] rank-3 shapes where b is [a[0], a[2], random_oc]."""
        pl, const = op["operands"]

        assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        b_oc = testGen.makeShape(1)[0]
        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])

        return [a_shape, b_shape]
326
Kevin Cheng550ccc52021-03-03 11:21:43 -0800327
class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for operators that take
    attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
    tuples where the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function.

    Every generator is a static method taking (testGen, opName, shapeList, dtype).
    """

    def __init__(self):
        pass

    @staticmethod
    def _digits(value, base, count):
        """Split ``value`` into ``count`` base-``base`` digits, most significant first.

        Lets a single flat loop over ``range(base ** count)`` visit every
        combination of ``count`` values in ``[0, base)`` exactly once.
        """
        digits = [0] * count
        for i in range(count - 1, -1, -1):
            value, digits[i] = divmod(value, base)
        return digits

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype):
        """Build the axis argument for operators that take a single axis"""
        shape = shapeList[0]
        return [("axis{}".format(a), [a]) for a in range(len(shape))]

    @staticmethod
    def agConv2D(testGen, opName, shapeList, dtype):
        """Enumerate every stride/padding/dilation combination for CONV2D-like ops.

        Per axis, strides range over 1..max_conv_stride and dilations over
        1..max_conv_dilation. Each of the four padding values ranges over
        0..max_conv_padding.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        maxStride = testGen.args.max_conv_stride
        maxPadding = testGen.args.max_conv_padding + 1
        maxDilation = testGen.args.max_conv_dilation

        # Strides, padding, dilations
        for stride in range(0, maxStride ** 2):
            for padding in range(0, maxPadding ** 4):
                for dilation in range(0, maxDilation ** 2):

                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    # Decode the flat index as four base-maxPadding digits so
                    # each padding combination appears exactly once.
                    p = TosaArgGen._digits(padding, maxPadding, 4)
                    d = [dilation // maxDilation + 1, dilation % maxDilation + 1]

                    # 4 padding parameters for regular conv2d
                    arg_list.append(
                        (
                            "st{}{}_pad{}{}{}{}_dilat{}{}".format(
                                s[0], s[1], p[0], p[1], p[2], p[3], d[0], d[1]
                            ),
                            [s, p, d],
                        )
                    )
        return arg_list

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtype):
        """Enumerate stride/out-padding/dilation combinations for TRANSPOSE_CONV2D.

        Each entry also carries the output shape computed from those values.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        maxStride = testGen.args.max_conv_stride
        maxPadding = testGen.args.max_conv_padding + 1
        maxDilation = testGen.args.max_conv_dilation

        # Strides, padding, dilations
        for stride in range(0, maxStride ** 2):
            for out_padding in range(0, (maxPadding) ** 2):
                for dilation in range(0, maxDilation ** 2):

                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    p = TosaArgGen._digits(out_padding, maxPadding, 2)
                    d = [dilation // maxDilation + 1, dilation % maxDilation + 1]

                    oh = (
                        ifm_shape[1]
                        - filter_shape[1]
                        - (filter_shape[1] - 1) * (d[0] - 1)
                        + 2 * p[0]
                    ) // s[0] + 1

                    ow = (
                        ifm_shape[2]
                        - filter_shape[2]
                        - (filter_shape[2] - 1) * (d[1] - 1)
                        + 2 * p[1]
                    ) // s[1] + 1

                    # Output shape (named out_shape so the os module isn't shadowed)
                    out_shape = [ifm_shape[0], oh, ow, filter_shape[0]]

                    arg_list.append(
                        (
                            "st{}{}_outpad{}{}_dilat{}{}_os{}x{}x{}x{}".format(
                                s[0],
                                s[1],
                                p[0],
                                p[1],
                                d[0],
                                d[1],
                                out_shape[0],
                                out_shape[1],
                                out_shape[2],
                                out_shape[3],
                            ),
                            [s, p, d, out_shape],
                        )
                    )

        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype):
        """Generate 0/1 padding patterns for PAD.

        Each loop index v is a bitmask over the flattened (rank, 2) padding
        array: bit r set means flat padding entry r is 1.

        NOTE(review): only range(rank ** 2) masks are generated, which is fewer
        than the 2 ** (rank * 2) needed to cover every side of every dimension
        (rank 1 yields only the all-zero pattern). Kept as-is so the generated
        test set doesn't change; revisit if exhaustive coverage is wanted.
        """
        arg_list = []
        rank = len(shapeList[0])

        for v in range(rank ** 2):

            # Create a flat array of rank * 2 padding values
            paddings = np.zeros((rank * 2), dtype=np.int32)

            # Fill in the 1's
            for r in range(rank * 2):
                if (v >> r) & 1:
                    paddings[r] = 1

            # Reshape back to a 2D array
            paddings = paddings.reshape((rank, 2))

            arg_list.append(("pad{0:b}".format(v), [paddings]))

        return arg_list

    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype):
        """Enumerate every kernel/stride/padding combination for pooling ops.

        Kernels range over 2..max_pooling_kernel+1 and strides over
        1..max_pooling_stride per axis. Each of the four padding values ranges
        over 0..max_pooling_padding.
        """
        arg_list = []

        shape = shapeList[0]
        assert len(shape) == 4

        maxStride = testGen.args.max_pooling_stride
        maxKernel = testGen.args.max_pooling_kernel
        maxPadding = testGen.args.max_pooling_padding + 1

        for kernel in range(0, maxKernel ** 2):
            for stride in range(0, maxStride ** 2):
                for padding in range(0, maxPadding ** 4):
                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    k = [(kernel // maxKernel) + 2, (kernel % maxKernel) + 2]
                    # Decode the flat index as four base-maxPadding digits so
                    # each padding combination appears exactly once.
                    p = TosaArgGen._digits(padding, maxPadding, 4)

                    arg_list.append(
                        (
                            "st{}{}_kern{}{}_pad{}{}{}{}".format(
                                s[0], s[1], k[0], k[1], p[0], p[1], p[2], p[3]
                            ),
                            [k, s, p],
                        )
                    )
        return arg_list

    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype):
        """Enumerate the legal CAST output dtypes for ``inDtype``."""
        arg_list = []

        # Enumerate the output types here
        if inDtype == DType.INT8:
            dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT16:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT32:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FLOAT:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))

        return arg_list

    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype):
        """Enumerate output dtype x scale32 x double_round x per_channel for RESCALE."""
        arg_list = []

        # Enumerate the output types here
        for dtype in [DType.INT8, DType.INT16, DType.INT32]:
            for scale32 in [False, True]:
                for double_round in [False, True]:
                    for per_channel in [False, True]:

                        if inDtype == DType.INT48 and scale32:
                            # Illegal condition. Must be scale32=False
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[dtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [dtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

    @staticmethod
    def agMul(testGen, opName, shapeList, dtype):
        """Generate the MUL shift argument: random shifts for INT32, else 0."""
        arg_list = []

        # Compare by value: `is` on ints only works by accident of CPython's
        # small-int cache and fails for e.g. numpy integer dtypes.
        if dtype == DType.INT32:
            for p in range(testGen.args.num_rand_permutations):

                shift = testGen.randInt(0, 32)

                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("perm0_shift0", [0]))

        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype):
        """Generate both values of the ARITHMETIC_RIGHT_SHIFT round flag."""
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    # Helper function for reshape. Gets some factors of a larger number.
    @staticmethod
    def getFactors(val, start=1):
        """Return the divisors of ``val`` in [start, isqrt(val)], ascending."""
        factors = []

        for i in range(start, int(np.sqrt(val)) + 1):
            if (val % i) == 0:
                factors.append(i)

        return factors

    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype):
        """Generate random, non-duplicate target shapes with the same element count."""
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)

                # Toss in a -1 sometimes
                minusOne = testGen.randInt(0, newRank * 4)
                if minusOne < newRank:
                    newShape[minusOne] = -1

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list

    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype):
        """Pick up to num_rand_permutations distinct random axis permutations."""
        ifm_shape = shapeList[0]

        # Get all permutations
        permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Get random permutation generator that uses all permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create list of required amount of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list

    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype):
        """Generate random (begin, size) slices; draws with a zero size are dropped."""
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            begin = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    begin.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    begin.append(0)
                    size.append(1)

            if valid:
                arg_list.append(("perm{}".format(p), [begin, size]))
        return arg_list

    @staticmethod
    def agTile(testGen, opName, shapeList, dtype):
        """Generate random per-dimension TILE multiples."""
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random, but small multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                multiples.append(testGen.randInt(1, 4))

            arg_list.append(("perm{}".format(p), [multiples]))

        return arg_list

    @staticmethod
    def agResize(testGen, opName, shapeList, dtype):
        """Generate RESIZE arguments for each legal (mode, output dtype) pair.

        FLOAT outputs use floating-point stride/offset. Integer outputs use
        fixed-point values, with the shift reduced from 11 until the stride
        and offset fit in a signed 16-bit range.
        """
        arg_list = []

        ifm_shape = shapeList[0]

        for m in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations. Pick legal output types
            if m == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif m == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):

                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    output_dims = [testGen.randInt(1), testGen.randInt(1)]
                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]
                        arg_list.append(
                            (
                                "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                    m,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride_fp[0],
                                    stride_fp[1],
                                    offset_fp[0],
                                    offset_fp[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )
                    else:
                        shift = 11
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))

                        while (
                            stride_y >= 32768
                            or stride_x >= 32768
                            or offset_y >= 32768
                            or offset_x >= 32768
                            or offset_y < -32768
                            or offset_x < -32768
                        ):
                            shift = shift - 1
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                        arg_list.append(
                            (
                                "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                    m,
                                    shift,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride[0],
                                    stride[1],
                                    offset[0],
                                    offset[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )

        return arg_list

    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype):
        """Generate both COND_IF condition values (False, True)."""
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype):
        """Generate WHILE_LOOP iteration counts: 0, 1 and more than 1."""
        arg_list = []

        for iterations in [0, 1, 4]:
            arg_list.append(("iter{}".format(iterations), [iterations]))

        return arg_list
872
Kevin Cheng550ccc52021-03-03 11:21:43 -0800873
Eric Kunzee5e26762020-10-13 16:11:07 -0700874class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +0100875 # Maximum rank of tensor supported by test generator.
876 TOSA_TENSOR_MAX_RANK = 6
877
    def __init__(self, args):
        """Set up the test generator from the parsed options.

        Args:
            args: option object (presumably an argparse.Namespace — confirm);
                this method reads args.output_dir and args.random_seed and
                keeps the whole object as self.args for later use.
        """
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        # Replaced per test by createSerializer()
        self.ser = None
        # One RNG seeded from random_seed, so a given seed reproduces the same draws
        self.rng = np.random.default_rng(self.random_seed)
        # NOTE(review): these two run before quantGen/targetted_shape are set;
        # keep this order unless their (unseen) bodies are checked.
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None
889
def createSerializer(self, opName, testPath):
    """Start a new serializer writing to <basePath>/<opName>/<testPath>.

    The output directory is created if it does not already exist.
    """
    self.testPath = os.path.join(opName, testPath)
    outputDir = os.path.join(self.basePath, self.testPath)
    os.makedirs(outputDir, exist_ok=True)
    self.ser = ts.TosaSerializer(outputDir)
896
def getSerializer(self):
    """Return the current serializer (None until createSerializer runs)."""
    current = self.ser
    return current
899
def serialize(self, testName):
    """Write the current test to disk.

    Writes ``<testName>.tosa`` (the serialized graph, binary) and
    ``desc.json`` (the test description) into the current test directory.
    """
    testDir = os.path.join(self.basePath, self.testPath)
    graphFile = "{}.tosa".format(testName)

    with open(os.path.join(testDir, graphFile), "wb") as fd:
        fd.write(self.ser.serialize())

    with open(os.path.join(testDir, "desc.json"), "w") as fd:
        fd.write(self.ser.writeJson(graphFile))
Eric Kunzee5e26762020-10-13 16:11:07 -0700908
def getRandTensor(self, shape, dtype):
    """Return a random numpy array of ``shape`` for a TOSA ``dtype``.

    BOOL gives a random bool array. Integer types are drawn uniformly from a
    per-type range and stored as int32, or int64 for INT48. FLOAT is uniform
    in [-2.0, 2.0).

    Raises:
        Exception: if ``dtype`` is not supported.
    """
    RAND_SHIFT_FACTOR = 0.5
    RAND_SCALE_FACTOR = 4.0

    if dtype == DType.BOOL:
        # (An unused ``np.bool`` reference was removed: that alias raises
        # AttributeError on NumPy 1.24-1.26.)
        return np.bool_(self.rng.choice(a=[False, True], size=shape))
    elif dtype == DType.INT4:
        return np.int32(self.rng.integers(low=-7, high=8, size=shape))
    elif dtype == DType.INT8:
        return np.int32(self.rng.integers(low=-127, high=128, size=shape))
    elif dtype == DType.INT16:
        return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
    elif dtype == DType.INT32:
        return np.int32(
            self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
        )
    elif dtype == DType.INT48:
        return np.int64(
            self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
        )
    elif dtype == DType.FLOAT:
        # Centre [0, 1) on zero, then scale it. Without the parentheses,
        # operator precedence gave random - 2.0, i.e. only values in [-2, -1).
        return np.float32(
            (self.rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR
        )
    else:
        raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700936
def buildPlaceholderTensors(self, shape_list, dtype_list):
    """Add one random-valued placeholder per (shape, dtype) pair.

    Returns the created placeholder tensors in input order.
    """
    assert len(shape_list) == len(dtype_list)

    placeholders = []
    for shape, dtype in zip(shape_list, dtype_list):
        values = self.getRandTensor(shape, dtype)
        placeholders.append(self.ser.addPlaceholder(shape, dtype, values))
    return placeholders
947
def buildConstTensors(self, shape_list, dtype_list):
    """Add one random-valued constant per (shape, dtype) pair.

    Returns the created constant tensors in input order.
    """
    assert len(shape_list) == len(dtype_list)

    consts = []
    for shape, dtype in zip(shape_list, dtype_list):
        values = self.getRandTensor(shape, dtype)
        consts.append(self.ser.addConst(shape, dtype, values))
    return consts
958
def makeShape(self, rank):
    """Return an int32 shape array of length ``rank``.

    If a target shape was set with setTargetShape(), that shape is returned
    and ``rank`` is ignored. Otherwise each dimension is drawn from
    [tensor_shape_range[0], tensor_shape_range[1]).
    """
    if not self.targetted_shape:
        low = self.args.tensor_shape_range[0]
        high = self.args.tensor_shape_range[1]
        return np.int32(self.rng.integers(low=low, high=high, size=rank))
    return np.int32(self.targetted_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -0700969
def setTargetShape(self, shape):
    """Make later makeShape() calls return ``shape``; a falsy value clears it."""
    self.targetted_shape = shape
972
def randInt(self, low=0, high=256):
    """Return one random np.int32 from the half-open range [low, high)."""
    draws = self.rng.integers(low=low, high=high, size=1)
    return np.int32(draws)[0]
975
def getRandNumberDType(self, dtype):
    """Return one random scalar suitable for ``dtype``.

    FLOAT gives a float in [0, 1) and BOOL a random bool. INT48 gives an
    np.int64. The other integer types give an np.int32 from a per-type range.

    Raises:
        Exception: if ``dtype`` is not supported.
    """
    if dtype == DType.FLOAT:
        return self.rng.random()
    if dtype == DType.BOOL:
        return self.rng.choice([False, True])
    if dtype == DType.INT48:
        # 48-bit values do not fit in int32, so return a 64-bit scalar.
        return np.int64(self.rng.integers(-(1 << 47), (1 << 47), size=1))[0]

    if dtype == DType.INT4:
        bounds = (-7, 8)
    elif dtype == DType.INT8:
        bounds = (-127, 128)
    elif dtype == DType.INT16:
        bounds = (-32768, 32768)
    elif dtype == DType.INT32:
        bounds = (-(1 << 31), (1 << 31))
    else:
        raise Exception("Unknown dtype: {}".format(dtype))

    return np.int32(self.rng.integers(bounds[0], bounds[1], size=1))[0]
997
def shapeStr(self, shape):
    """Format ``shape`` as "d0xd1x..." for use in test names."""
    return "x".join(str(dim) for dim in shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07001006
def typeStr(self, t):
    """Return the short dtype name used in test names (e.g. "i8", "float").

    A list of dtypes is rendered from its first two entries as "<a>x<b>".

    Raises:
        Exception: if ``t`` has no short name.
    """
    if isinstance(t, list):
        assert len(t) >= 2
        return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))

    for dtype, short_name in (
        (DType.BOOL, "b"),
        (DType.INT4, "i4"),
        (DType.INT8, "i8"),
        (DType.UINT8, "u8"),
        (DType.INT16, "i16"),
        (DType.INT32, "i32"),
        (DType.INT48, "i48"),
        (DType.FLOAT, "float"),
    ):
        if t == dtype:
            return short_name
    raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001030
def typeWidth(self, t):
    """Return the bit width of an integer TOSA dtype.

    Raises:
        Exception: if ``t`` is not one of the supported integer types.
    """
    if t == DType.INT4:
        return 4
    elif t == DType.INT8:
        return 8
    elif t == DType.UINT8:
        return 8
    elif t == DType.INT16:
        return 16
    elif t == DType.INT32:
        return 32
    elif t == DType.INT48:
        return 48
    else:
        # The message previously said "cannot convert to string" (copied
        # from typeStr), which misdescribes this lookup.
        raise Exception("Unknown dtype, cannot determine width: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001047
# Operator build functions
# Argument generators return a list of tuples (stringDescriptor, [build_fcn_arg_list]),
# where the string descriptor is used to generate the test name and
# the build_fcn_arg_list is expanded and passed to one of the operator
# build functions below.
1053
def build_unary(self, op, a, qinfo=None):
    """Add a single-input operator ``op`` on ``a``; returns the output tensor.

    The output tensor comes from OutputShaper.unaryOp. ``qinfo`` is
    forwarded to the serializer unchanged.
    """
    out = OutputShaper.unaryOp(self.ser, a)
    self.ser.addOperator(op, [a.name], [out.name], None, qinfo)
    return out
1058
def build_binary_broadcast(self, op, a, b):
    """Add a broadcasting two-input operator; returns the output tensor."""
    out = OutputShaper.binaryBroadcastOp(self.ser, a, b)
    self.ser.addOperator(op, [a.name, b.name], [out.name])
    return out
1063
def build_binary_nonbroadcast(self, op, a, b):
    """Add a non-broadcasting two-input operator; returns the output tensor."""
    out = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
    self.ser.addOperator(op, [a.name, b.name], [out.name])
    return out
1068
def build_arithmetic_right_shift(self, op, a, b, round):
    """Add an ARITHMETIC_RIGHT_SHIFT of ``a`` by ``b``.

    ``round`` is stored in the operator attribute. Returns the output tensor.
    """
    out = OutputShaper.binaryBroadcastOp(self.ser, a, b)

    shift_attr = ts.TosaSerializerAttribute()
    shift_attr.ArithmeticRightShiftAttribute(round)

    self.ser.addOperator(op, [a.name, b.name], [out.name], shift_attr)
    return out
1077
def build_mul(self, op, a, b, shift):
    """Add a MUL operator with the given ``shift`` attribute.

    For any non-FLOAT input the result tensor is forced to INT32.
    Returns the output tensor.
    """
    out = OutputShaper.binaryBroadcastOp(self.ser, a, b)
    if a.dtype != DType.FLOAT:
        # Integer products are always widened to 32 bits.
        out.setDtype(DType.INT32)

    mul_attr = ts.TosaSerializerAttribute()
    mul_attr.MulAttribute(shift)

    self.ser.addOperator(op, [a.name, b.name], [out.name], mul_attr)
    return out
1091
def build_table(self, op, a):
    """Add a TABLE operator with a random constant lookup table.

    INT16 inputs get a 513-entry INT16 table. Any other input must be INT8
    and gets a 256-entry INT8 table. Returns the output tensor.
    """
    if a.dtype == DType.INT16:
        table_dtype, table_len = DType.INT16, 513
    else:
        assert a.dtype == DType.INT8
        table_dtype, table_len = DType.INT8, 256
    table_arr = self.getRandTensor([table_len], table_dtype)

    table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
    result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
    self.ser.addOperator(op, [a.name, table_tens.name], [result_tens.name], None)
    return result_tens
1107
def build_select(self, op, cond, a, b):
    """Add a SELECT choosing between ``a`` and ``b`` by ``cond``."""
    out = OutputShaper.selectOp(self.ser, cond, a, b)
    self.ser.addOperator(op, [cond.name, a.name, b.name], [out.name])
    return out
1112
def build_comparison(self, op, a, b):
    """Add a two-input comparison operator; returns the output tensor."""
    out = OutputShaper.binaryComparisonOp(self.ser, a, b)
    self.ser.addOperator(op, [a.name, b.name], [out.name])
    return out
1117
def build_argmax(self, op, a, axis):
    """Add an ARGMAX over ``axis``; returns the output tensor."""
    out = OutputShaper.argmaxOp(self.ser, a, axis)

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op, [a.name], [out.name], axis_attr)
    return out
1126
def build_pool2d(self, op, input, kernel, stride, pad, qinfo=None):
    """Add a 2-D pooling operator; returns the output tensor.

    ``kernel``, ``stride`` and ``pad`` go into the Pool2dAttribute.
    """
    out = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)

    pool_attr = ts.TosaSerializerAttribute()
    pool_attr.Pool2dAttribute(kernel, stride, pad)

    self.ser.addOperator(op, [input.name], [out.name], pool_attr, qinfo)
    return out
1135
def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
    """Add a CONV2D operator; returns the output tensor.

    ``padding`` must have exactly four entries.
    """
    assert len(padding) == 4
    out = OutputShaper.conv2dOp(self.ser, ifm, filter, strides, padding, dilations)

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.Conv2dAttribute(padding, strides, dilations)

    inputs = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op, inputs, [out.name], conv_attr, qinfo)
    return out
1149
def build_transpose_conv2d(
    self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
):
    """Add a TRANSPOSE_CONV2D producing ``output_shape``; returns the output.

    ``outpad`` must have exactly two entries.
    """
    assert len(outpad) == 2
    out = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)

    tconv_attr = ts.TosaSerializerAttribute()
    tconv_attr.TransposeConv2DAttribute(outpad, stride, dilation, output_shape)

    inputs = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op, inputs, [out.name], tconv_attr, qinfo)
    return out
1163
def build_depthwise_conv2d(
    self, op, ifm, filter, bias, strides, padding, dilations, qinfo
):
    """Add a DEPTHWISE_CONV2D operator; returns the output tensor.

    It reuses Conv2dAttribute for padding/strides/dilations.
    """
    out = OutputShaper.depthwiseConv2dOp(
        self.ser, ifm, filter, strides, padding, dilations
    )

    conv_attr = ts.TosaSerializerAttribute()
    conv_attr.Conv2dAttribute(padding, strides, dilations)

    inputs = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op, inputs, [out.name], conv_attr, qinfo)
    return out
1178
def build_fully_connected(self, op, ifm, filter, bias, qinfo):
    """Add a FULLY_CONNECTED operator (no attribute); returns the output."""
    out = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)
    inputs = [ifm.name, filter.name, bias.name]
    self.ser.addOperator(op, inputs, [out.name], None, qinfo)
    return out
1186
def build_matmul(self, op, a, b, qinfo):
    """Add a MATMUL of ``a`` and ``b``; returns the output tensor."""
    out = OutputShaper.matmulOp(self.ser, a, b)
    self.ser.addOperator(op, [a.name, b.name], [out.name], None, qinfo)
    return out
1191
def build_reduce(self, op, a, axis):
    """Add a reduction operator along ``axis``; returns the output tensor."""
    result_tens = OutputShaper.reduceOp(self.ser, a, axis)

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    # Outputs are passed as a list, as in every other build_* function.
    # This call previously passed the bare name string.
    self.ser.addOperator(op, [a.name], [result_tens.name], attr)
    return result_tens
1200
def build_clamp(self, op, a):
    """Add a CLAMP operator with random bounds drawn from randInt().

    The smaller of two random ints is the lower bound and the larger is the
    upper bound. For FLOAT inputs they fill the last two ClampAttribute
    arguments and the first two are 0. For other inputs they fill the first
    two and the last two are 0. Returns the output tensor.
    """
    result_tens = OutputShaper.unaryOp(self.ser, a)

    attr = ts.TosaSerializerAttribute()

    first, second = self.randInt(), self.randInt()
    lower, upper = min(first, second), max(first, second)

    if a.dtype == DType.FLOAT:
        attr.ClampAttribute(0, 0, lower, upper)
    else:
        attr.ClampAttribute(lower, upper, 0, 0)

    self.ser.addOperator(op, [a.name], [result_tens.name], attr)
    return result_tens
1216
def build_leaky_relu(self, op, a):
    """Add a LEAKY_RELU whose attribute value is a random float in [0, 1)."""
    out = OutputShaper.unaryOp(self.ser, a)
    relu_attr = ts.TosaSerializerAttribute()

    rand_val = self.getRandNumberDType(DType.FLOAT)
    relu_attr.LeakyReluAttribute(rand_val)

    self.ser.addOperator(op, [a.name], [out.name], relu_attr)
    return out
1225
# NOTE: Needs an additional type/input; only ``a`` is wired up here.
def build_prelu(self, op, a):
    """Add a PRELU on ``a`` alone; returns the output tensor."""
    out = OutputShaper.unaryOp(self.ser, a)
    self.ser.addOperator(op, [a.name], [out.name])
    return out
1232
def build_relun(self, op, a):
    """Add a RELUN with a random bound of ``a``'s dtype.

    FLOAT inputs put the bound in the second ReluNAttribute argument (first
    is 0). Other inputs put it in the first argument (second is 0).
    Returns the output tensor.
    """
    result_tens = OutputShaper.unaryOp(self.ser, a)

    attr = ts.TosaSerializerAttribute()

    bound = self.getRandNumberDType(a.dtype)
    if a.dtype == DType.FLOAT:
        attr.ReluNAttribute(0, bound)
    else:
        attr.ReluNAttribute(bound, 0)

    self.ser.addOperator(op, [a.name], [result_tens.name], attr)
    return result_tens
1245
def build_sigmoid(self, op, a):
    """Add a SIGMOID on ``a``; returns the output tensor."""
    out = OutputShaper.unaryOp(self.ser, a)
    self.ser.addOperator(op, [a.name], [out.name])
    return out
1250
def build_tanh(self, op, a):
    """Add a TANH on ``a``; returns the output tensor."""
    out = OutputShaper.unaryOp(self.ser, a)
    self.ser.addOperator(op, [a.name], [out.name])
    return out
1255
def build_concat(self, op, a, b, axis):
    """Add a CONCAT of ``a`` and ``b`` along ``axis``; returns the output tensor."""
    result_tens = OutputShaper.concatOp(self.ser, a, b, axis)

    attr = ts.TosaSerializerAttribute()
    attr.AxisAttribute(axis)

    self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
    # Return the output like every other build_* function (this was missing,
    # so callers received None).
    return result_tens
1263
def build_pad(self, op, a, padding, qinfo):
    """Add a PAD operator; returns the output tensor.

    ``padding`` must have a ``.shape``. It is added as an INT32 constant
    input. It is one of the few operands that is not randomly generated.
    """
    result_tens = OutputShaper.padOp(self.ser, a, padding)

    padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)

    self.ser.addOperator(
        op, [a.name, padding_tens.name], [result_tens.name], None, qinfo
    )
    # Return the output like every other build_* function (this was missing,
    # so callers received None).
    return result_tens
Eric Kunzee5e26762020-10-13 16:11:07 -07001275
def build_reshape(self, op, a, newShape):
    """Add a RESHAPE of ``a`` to ``newShape``; returns the output tensor."""
    out = OutputShaper.reshapeOp(self.ser, a, newShape)

    reshape_attr = ts.TosaSerializerAttribute()
    reshape_attr.ReshapeAttribute(newShape)

    self.ser.addOperator(op, [a.name], [out.name], reshape_attr)
    return out
1284
def build_reverse(self, op, a, axis):
    """Add a REVERSE of ``a`` along ``axis``; returns the output tensor."""
    out = OutputShaper.unaryOp(self.ser, a)

    axis_attr = ts.TosaSerializerAttribute()
    axis_attr.AxisAttribute(axis)

    self.ser.addOperator(op, [a.name], [out.name], axis_attr)
    return out
1293
def build_transpose(self, op, a, perms):
    """Add a TRANSPOSE; ``perms`` becomes a 1-D INT32 constant input."""
    out = OutputShaper.transposeOp(self.ser, a, perms)

    perm_values = np.int32(perms)
    perms_tens = self.ser.addConst([len(perms)], DType.INT32, perm_values)

    self.ser.addOperator(op, [a.name, perms_tens.name], [out.name])
    return out
1301
def build_slice(self, op, a, begin, size):
    """Add a SLICE of ``a`` with the given ``begin``/``size``; returns the output."""
    out = OutputShaper.sliceOp(self.ser, a, begin, size)

    slice_attr = ts.TosaSerializerAttribute()
    slice_attr.SliceAttribute(begin, size)

    self.ser.addOperator(op, [a.name], [out.name], slice_attr)
    return out
1310
def build_tile(self, op, a, multiples):
    """Add a TILE of ``a`` by ``multiples``; returns the output tensor."""
    out = OutputShaper.tileOp(self.ser, a, multiples)

    tile_attr = ts.TosaSerializerAttribute()
    tile_attr.TileAttribute(multiples)

    self.ser.addOperator(op, [a.name], [out.name], tile_attr)
    return out
1319
def build_gather(self, op, values):
    """Add a GATHER with a freshly generated, in-range index tensor.

    Dimensions 0 and 1 of ``values`` are treated as N and K. W is drawn with
    randInt over the tensor shape range. An (N, W) INT32 constant of indices
    in [0, K) is added. Returns the output tensor.
    """
    K = values.shape[1]
    W = self.randInt(
        self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
    )
    N = values.shape[0]
    indices_arr = np.int32(self.rng.integers(low=0, high=K, size=[N, W]))
    indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)

    result_tens = OutputShaper.gatherOp(self.ser, values, indices)
    self.ser.addOperator(op, [values.name, indices.name], [result_tens.name])
    return result_tens
1339
def build_scatter(self, op, values_in, input):
    """Add a SCATTER with a freshly generated, in-range index tensor.

    Dimensions 0 and 1 of ``values_in`` are treated as N and K, and
    dimension 1 of ``input`` as W. An (N, W) INT32 constant of indices in
    [0, K) is added. Returns the output tensor.
    """
    N, K = values_in.shape[0], values_in.shape[1]
    W = input.shape[1]
    indices_arr = np.int32(self.rng.integers(low=0, high=K, size=[N, W]))
    indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)

    result_tens = OutputShaper.scatterOp(self.ser, values_in, indices, input)
    operands = [values_in.name, indices.name, input.name]
    self.ser.addOperator(op, operands, [result_tens.name])
    return result_tens
1359
def build_resize(
    self,
    op,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
):
    """Add a RESIZE operator; returns the output tensor.

    The integer (stride/offset/shift) and floating-point
    (stride_fp/offset_fp) parameters are forwarded unchanged to both the
    output shaper and the ResizeAttribute.
    """
    out = OutputShaper.resizeOp(
        self.ser, input, mode, stride, offset, shift, stride_fp, offset_fp,
        output_dims, input_dtype, output_dtype,
    )

    resize_attr = ts.TosaSerializerAttribute()
    resize_attr.ResizeAttribute(
        output_dims, stride, offset, shift, stride_fp, offset_fp, mode
    )

    self.ser.addOperator(op, [input.name], [out.name], resize_attr)
    return out
1396
def build_identityn(self, op, val, val2):
    """Add an IDENTITYN over two inputs; returns only the first output tensor."""
    outputs = [OutputShaper.unaryOp(self.ser, tens) for tens in (val, val2)]
    self.ser.addOperator(
        op, [val.name, val2.name], [tens.name for tens in outputs]
    )
    return outputs[0]
1405
def build_placeholder(self, op, val):
    """Route ``val`` through an IDENTITY op (``op`` itself is ignored).

    The identity op avoids a warning in the reference model.
    """
    identity_op = Op.IDENTITY
    return self.build_unary(identity_op, val)
1409
# Type conversion
def build_cast(self, op, val, out_dtype):
    """Add a CAST of ``val`` to ``out_dtype``; returns the output tensor."""
    out = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
    self.ser.addOperator(op, [val.name], [out.name])
    return out
1415
def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
    """Add a RESCALE converting ``val`` to ``out_dtype``; returns the output.

    INT8 tensors get a random zero point; other types use 0. A random scale
    is generated per channel (last dimension of ``val``) when ``per_channel``
    is set, or a single one otherwise. It is derived from the input/output
    type widths and clipped to what the chosen multiplier width can express.
    It is then converted to multiplier/shift pairs. Any shift outside
    [2, 62] marks the test as an expected failure.
    """
    result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)

    nc = val.shape[-1] if per_channel else 1

    in_type_width = self.typeWidth(val.dtype)
    out_type_width = self.typeWidth(out_dtype)

    if val.dtype == DType.INT8:
        input_zp = self.randInt(-128, 127)
        in_type_width = in_type_width + 1
    else:
        input_zp = 0

    if out_dtype == DType.INT8:
        output_zp = self.randInt(-128, 127)
        out_type_width = out_type_width + 1
    else:
        output_zp = 0

    # scale = a * (2^output_width) / (2^input_width), with a random in [0, 1)
    a = np.float32(self.rng.random(size=[nc]))
    scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

    if scale32:
        # Cap the scaling at 2^31 - 1 for scale32
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
    else:
        # Cap the scaling at 2^15 - 1 for scale16
        scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

    multiplier_arr = np.int32(np.zeros(shape=[nc]))
    shift_arr = np.int32(np.zeros(shape=[nc]))

    for i in range(nc):
        multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
            scale_arr[i], scale32
        )
        if shift_arr[i] < 2 or shift_arr[i] > 62:
            self.ser.setExpectedFailure(True, "OpRescale: invalid shift value")

    attr = ts.TosaSerializerAttribute()
    attr.RescaleAttribute(
        input_zp,
        output_zp,
        multiplier_arr,
        shift_arr,
        scale32,
        double_round,
        per_channel,
    )

    self.ser.addOperator(op, [val.name], [result_tens.name], attr)
    return result_tens
1480
def build_cond_if_const(self, op, then_tens, else_tens, cond):
    """Add a COND_IF whose then/else blocks each output a random INT32 constant.

    Only the shape of ``then_tens`` is used; ``else_tens`` is ignored.
    ``cond`` is stored in a constant BOOL condition tensor. Returns the
    COND_IF result tensor.
    """
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    # Random contents for the two branches, both shaped like then_tens
    out_shape = then_tens.shape
    then_arr = np.int32(self.rng.integers(0, 255, size=out_shape))
    else_arr = np.int32(self.rng.integers(0, 255, size=out_shape))

    result_tens = self.ser.addOutput(out_shape, DType.INT32)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(op, [cond_tens.name], [result_tens.name], attr)

    # Each block just outputs its own constant
    for block_name, block_arr in ((then_block, then_arr), (else_block, else_arr)):
        self.ser.startBasicBlock(block_name)
        block_const = self.ser.addConst(out_shape, DType.INT32, block_arr)
        self.ser.addOutputTensor(block_const)

    return result_tens
1516
def build_cond_if_binary(self, op, a, b, cond):
    """Add a COND_IF whose then block computes a + b and else block a - b.

    ``cond`` is stored in a constant BOOL condition tensor. Returns the
    COND_IF result tensor, which has ``a``'s shape and dtype.
    """
    cond_tens = self.ser.addConst([], DType.BOOL, [cond])

    result_tens = self.ser.addOutput(a.shape, a.dtype)
    self.ser.currBasicBlock.addOutput(result_tens.name)

    then_block, else_block = "THEN_BLOCK", "ELSE_BLOCK"
    attr = ts.TosaSerializerAttribute()
    attr.CondIfAttribute(then_block, else_block)

    self.ser.addOperator(
        op, [cond_tens.name, a.name, b.name], [result_tens.name], attr
    )

    # Both blocks take (a, b) and produce one tensor shaped like a
    for block_name, block_op in ((then_block, Op.ADD), (else_block, Op.SUB)):
        self.ser.startBasicBlock(block_name)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        block_out = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(block_op, [a.name, b.name], [block_out.name])

    return result_tens
1551
def build_while_loop(self, op, a, iter_val):
    """Add a WHILE_LOOP that adds ``a`` into a zero accumulator ``iter_val`` times.

    The loop carries (iteration counter, a, accumulator). The COND block
    tests counter > 0. The BODY block computes acc + a and counter - 1.
    Returns the WHILE_LOOP's accumulator output tensor.
    """
    iter_tens = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

    cond_block = "COND_BLOCK"
    body_block = "BODY_BLOCK"

    attr = ts.TosaSerializerAttribute()
    attr.WhileLoopAttribute(cond_block, body_block)

    # Accumulator starts at zero
    acc_init_val = np.int32(np.zeros(a.shape))
    acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

    # Tensors produced by the WHILE_LOOP operator itself
    iter_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    a_out = self.ser.addIntermediate(a.shape, a.dtype)
    acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

    self.ser.addOperator(
        op,
        [iter_tens.name, a.name, acc.name],
        [iter_out.name, a_out.name, acc_out.name],
        attr,
    )

    # COND block: inputs (iter, a, acc), output iter > 0
    self.ser.startBasicBlock(cond_block)
    for loop_input in (iter_tens, a, acc):
        self.ser.addInputTensor(loop_input)
    zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
    cond_tens = self.ser.addOutput([], DType.BOOL)
    self.ser.addOperator(
        Op.GREATER, [iter_tens.name, zero_tens.name], [cond_tens.name]
    )

    # BODY block: inputs (iter, a, acc), outputs (iter - 1, a, acc + a).
    # Local intermediate tensors must be declared here for the outputs.
    self.ser.startBasicBlock(body_block)
    for loop_input in (iter_tens, a, acc):
        self.ser.addInputTensor(loop_input)
    one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
    iter_body_out = self.ser.addIntermediate(iter_tens.shape, iter_tens.dtype)
    acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
    self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
    self.ser.addOperator(
        Op.SUB, [iter_tens.name, one_tens.name], [iter_body_out.name]
    )
    for loop_output in (iter_body_out, a, acc_body_out):
        self.ser.addOutputTensor(loop_output)

    return acc_out
1604
def genOpTestList(
    self, opName, shapeFilter=None, rankFilter=None, dtypeFilter=None
):
    """Enumerate the tests to generate for operator ``opName``.

    Args:
        opName: key into TOSA_OP_LIST.
        shapeFilter: list of explicit shapes to test. None or empty (the
            default) means one randomly generated shape per rank.
        rankFilter: if given, only these ranks are generated.
        dtypeFilter: if given, only these dtypes are generated.

    If neither rankFilter nor an explicit shape is given, ranks are limited
    to 1-4 to keep test sizes reasonably small.

    Returns:
        list of (opName, testNameStr, dtype, shapeList, argumentsList) tuples.

    Raises:
        Exception: if ``opName`` is not in TOSA_OP_LIST.
    """
    try:
        op = self.TOSA_OP_LIST[opName]
    except KeyError as e:
        raise Exception("Cannot find op with name {}".format(opName)) from e

    # Re-seed, so each operator's list is generated from the same RNG state
    self.rng = np.random.default_rng(self.random_seed)

    build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
    rmin, rmax = op["rank"]

    default_test_rank_range = range(1, 5)

    # A None entry means "pick a random shape of the current rank".
    # (None default instead of a shared mutable [None] default list.)
    if not shapeFilter:
        shapeFilter = [None]

    testList = []
    for r in range(rmin, rmax + 1):
        if rankFilter is not None and r not in rankFilter:
            continue
        if (
            rankFilter is None
            and shapeFilter[0] is None
            and r not in default_test_rank_range
        ):
            continue

        for t in op["types"]:
            if dtypeFilter is not None and t not in dtypeFilter:
                continue

            for shape in shapeFilter:
                # Explicit shapes only apply to their own rank
                if shape is not None and len(shape) != r:
                    continue

                self.setTargetShape(shape)
                shapeList = tgen_fcn(self, op, r)

                shapeStr = self.shapeStr(shapeList[0])
                typeStr = self.typeStr(t)

                # (str, [build function args]) pairs; a single unnamed
                # entry when the op has no argument generator
                if agen_fcn:
                    argList = agen_fcn(self, opName, shapeList, t)
                else:
                    argList = [("", [])]

                for argStr, args in argList:
                    if argStr:
                        testStr = "{}_{}_{}_{}".format(
                            opName, shapeStr, typeStr, argStr
                        )
                    else:
                        testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
                    testList.append((opName, testStr, t, shapeList, args))

    return testList
1679
Kevin Cheng989cb052021-04-28 16:29:44 -07001680 def serializeTest(self, opName, testStr, dtype_or_dtypeList, shapeList, testArgs):
Eric Kunzee5e26762020-10-13 16:11:07 -07001681 try:
1682 op = self.TOSA_OP_LIST[opName]
1683 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001684 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07001685
1686 # Create a serializer
1687 self.createSerializer(opName, testStr)
1688
Kevin Cheng550ccc52021-03-03 11:21:43 -08001689 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
1690 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07001691 num_operands = pCount + cCount
1692
1693 if isinstance(dtype_or_dtypeList, list):
1694 dtypeList = dtype_or_dtypeList
1695 else:
1696 dtypeList = [dtype_or_dtypeList] * (num_operands)
1697
1698 assert (
1699 len(shapeList) == num_operands
1700 ), "shapeList length {} must match number of operands {}".format(
1701 len(shapeList), num_operands
1702 )
1703 assert (
1704 len(dtypeList) == num_operands
1705 ), "dtypeList length {} must match number of operands {}".format(
1706 len(dtypeList), num_operands
1707 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001708
1709 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001710 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001711 except KeyError:
1712 qgen = None
1713
1714 # Build the random tensor operands and the test
1715 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08001716
1717 # If test is ArithmeticRightShift, force value of operand[1] to be within [0, num_bits]
Kevin Cheng550ccc52021-03-03 11:21:43 -08001718 if op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
1719 assert (
1720 pCount == 2 and cCount == 0
1721 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
Kevin Chengaee1fac2020-11-11 13:54:06 -08001722
1723 placeholders = []
1724 for idx, shape in enumerate(shapeList[:]):
1725 if idx == 1:
Kevin Cheng989cb052021-04-28 16:29:44 -07001726 if dtypeList[idx] == DType.INT8:
Kevin Chengaee1fac2020-11-11 13:54:06 -08001727 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07001728 elif dtypeList[idx] == DType.INT16:
Kevin Chengaee1fac2020-11-11 13:54:06 -08001729 arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07001730 elif dtypeList[idx] == DType.INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -08001731 arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
1732 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001733 raise Exception("OpArithmeticRightShift: invalid input dtype")
Kevin Chengaee1fac2020-11-11 13:54:06 -08001734 else:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001735 arr = self.getRandTensor(shape, dtypeList[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -07001736 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
Kevin Chengaee1fac2020-11-11 13:54:06 -08001737
1738 tens.extend(placeholders)
Matthew Haddona44ac5e2021-07-27 16:31:16 +01001739 elif op["op"] == Op.SELECT:
1740 # Set datatype of condition tensor to boolean
1741 dtypeList[0] = DType.BOOL
1742 tens.extend(
1743 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
1744 )
1745 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001746 elif op["op"] == Op.DIV:
1747 assert (
1748 pCount == 2 and cCount == 0
1749 ), "Op.Div must have 2 placeholders, 0 consts"
1750
1751 placeholders = []
1752
1753 # Two invalid cases for Op.DIV:
1754 # 1. divisor == 0
Kevin Cheng47315e12021-05-13 17:41:28 -07001755 # 2. dividend == -(1<<31) and divisor == -1
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001756 while True:
1757 dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
1758 divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])
1759
1760 if (divisor_arr == 0).any():
1761 continue
1762
Kevin Cheng47315e12021-05-13 17:41:28 -07001763 if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001764 continue
1765
1766 break
1767
1768 placeholders.append(
1769 self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1770 )
1771 placeholders.append(
1772 self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1773 )
1774
1775 tens.extend(placeholders)
1776 elif op["op"] == Op.MUL:
1777 assert (
1778 pCount == 2 and cCount == 0
1779 ), "Op.MUL must have 2 placeholders, 0 consts"
1780
1781 if dtypeList[0] == DType.FLOAT:
1782 tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
1783 else:
1784 placeholders = []
1785
1786 # Make sure multiply result in int32 range
1787 shift = testArgs[0]
1788 if dtypeList[0] == DType.INT8:
1789 num_bits = 8
1790 elif dtypeList[0] == DType.INT16:
1791 num_bits = 16
1792 elif dtypeList[0] == DType.INT32:
1793 num_bits = 32
1794 else:
1795 raise Exception("OpMul: invalid input dtype")
1796
1797 for idx, shape in enumerate(shapeList[:]):
1798 low = -(2 ** (num_bits - 1))
1799 high = (2 ** (num_bits - 1)) - 1
1800
1801 a_arr = np.int32(
1802 self.rng.integers(low=low, high=high, size=shapeList[0])
1803 )
1804 b_arr = np.int32(
1805 self.rng.integers(low=low, high=high, size=shapeList[1])
1806 )
1807
1808 i = 0
1809 while True:
1810
1811 a_arr_64 = a_arr.astype(np.int64)
1812 b_arr_64 = b_arr.astype(np.int64)
1813
1814 if shift > 0:
1815 rounding = 1 << (shift - 1)
1816 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
1817 else:
1818 result_arr = a_arr_64 * b_arr_64
1819
1820 if (result_arr > -(2 ** 31)).all() and (
1821 result_arr <= ((2 ** 31) - 1)
1822 ).all():
1823 break
1824
1825 i = i + 1
1826 a_arr = a_arr // 2
1827 b_arr = b_arr // 2
1828
1829 placeholders.append(
1830 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1831 )
1832 placeholders.append(
1833 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1834 )
1835
1836 tens.extend(placeholders)
Kevin Chengaee1fac2020-11-11 13:54:06 -08001837 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07001838 tens.extend(
1839 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
1840 )
1841 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001842
1843 if qgen is not None:
Kevin Cheng989cb052021-04-28 16:29:44 -07001844 qinfo = qgen(self, op, dtypeList[0])
Eric Kunzee5e26762020-10-13 16:11:07 -07001845 else:
1846 qinfo = None
1847
1848 try:
1849 if qinfo is not None:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001850 resultName = build_fcn(self, op["op"], *tens, *testArgs, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07001851 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001852 resultName = build_fcn(self, op["op"], *tens, *testArgs)
Eric Kunzee5e26762020-10-13 16:11:07 -07001853 except TypeError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001854 print(
1855 "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
1856 build_fcn, tens, testArgs
1857 )
1858 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001859 raise e
1860
1861 # Save the serialized test
Kevin Cheng550ccc52021-03-03 11:21:43 -08001862 self.serialize("test")
Eric Kunzee5e26762020-10-13 16:11:07 -07001863
1864 def createDynamicOpLists(self):
1865
1866 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng550ccc52021-03-03 11:21:43 -08001867 KERNELS = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07001868
1869 for k in KERNELS:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001870 testName = "conv2d_{}x{}".format(k[0], k[1])
1871 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
1872 self.TOSA_OP_LIST[testName]["filter"] = k
1873 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001874
Kevin Cheng550ccc52021-03-03 11:21:43 -08001875 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
1876 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
1877 "depthwise_conv2d_TEMPLATE"
1878 ].copy()
1879 self.TOSA_OP_LIST[testName]["filter"] = k
1880 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001881
Kevin Cheng550ccc52021-03-03 11:21:43 -08001882 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
1883 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
1884 "transpose_conv2d_TEMPLATE"
1885 ].copy()
1886 self.TOSA_OP_LIST[testName]["filter"] = k
1887 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001888
1889 # Delete any templates after having created any dynamic ops
1890 # This is a two-pass operation because it's bad practice to delete
1891 # keys from dictionaries while iterating
1892 keyList = []
1893 for k in self.TOSA_OP_LIST:
1894 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001895 if self.TOSA_OP_LIST[k]["template"] == True:
Eric Kunzee5e26762020-10-13 16:11:07 -07001896 keyList.append(k)
1897 continue
1898 except KeyError:
1899 pass
1900
1901 for k in keyList:
1902 del self.TOSA_OP_LIST[k]
1903
1904 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001905 """Fill in default fields for ops if they aren't already specified.
1906 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07001907 for op in self.TOSA_OP_LIST:
1908
1909 # Required fields
1910 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001911 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001912 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001913 raise Exception(
1914 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
1915 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001916
1917 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001918 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001919 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001920 raise Exception(
1921 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
1922 op
1923 )
1924 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001925
1926 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001927 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001928 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001929 raise Exception(
1930 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
1931 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001932
1933 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001934 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001935 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001936 raise Exception(
1937 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
1938 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001939
1940 # Put in default rank range, if missing
1941 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001942 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001943 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001944 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07001945
1946 # Tensor operator list
1947 # 'op': op name
1948 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08001949 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
1950 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07001951 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
1952 # 'types': array of datatypes to be tested
Kevin Cheng550ccc52021-03-03 11:21:43 -08001953 TYPE_FP = [DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07001954
Kevin Cheng550ccc52021-03-03 11:21:43 -08001955 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
1956 TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07001957
Kevin Cheng550ccc52021-03-03 11:21:43 -08001958 TYPE_BOOL = [DType.BOOL]
1959 TYPE_FI32 = [DType.FLOAT, DType.INT32]
1960 TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
1961 TYPE_FI16 = [DType.FLOAT, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07001962
Kevin Cheng550ccc52021-03-03 11:21:43 -08001963 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07001964
Kevin Cheng989cb052021-04-28 16:29:44 -07001965 TYPE_CONV2D = [
1966 [DType.INT8, DType.INT8, DType.INT32],
1967 [DType.INT16, DType.INT8, DType.INT48],
1968 DType.FLOAT,
1969 ]
1970
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01001971 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07001972
1973 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08001974 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08001975 "argmax": {
1976 "op": Op.ARGMAX,
1977 "operands": (1, 0),
1978 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
1979 "types": TYPE_NARROW_INT_FP,
1980 },
Jared Smolens573ecd42021-03-04 15:24:10 -08001981 "avg_pool2d": {
1982 "op": Op.AVG_POOL2D,
1983 "operands": (1, 0),
1984 "rank": (4, 4),
1985 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
1986 "qgen": TosaQuantGen.qgUnary,
1987 "types": TYPE_NARROW_INT_FP,
1988 },
Eric Kunzee5e26762020-10-13 16:11:07 -07001989 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08001990 "conv2d_TEMPLATE": {
1991 "op": Op.CONV2D,
1992 "operands": (1, 2),
1993 "rank": (4, 4),
1994 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D),
1995 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07001996 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08001997 "template": True,
1998 },
Jared Smolens573ecd42021-03-04 15:24:10 -08001999 # Conv3d TBD
Eric Kunzee5e26762020-10-13 16:11:07 -07002000 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002001 "depthwise_conv2d_TEMPLATE": {
2002 "op": Op.DEPTHWISE_CONV2D,
2003 "operands": (1, 2),
2004 "filter": [1, 1],
2005 "rank": (4, 4),
2006 "build_fcn": (
2007 build_depthwise_conv2d,
2008 TosaTensorGen.tgDepthwiseConv2D,
2009 TosaArgGen.agConv2D,
2010 ),
2011 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07002012 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002013 "template": True,
2014 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002015 "fully_connected": {
2016 "op": Op.FULLY_CONNECTED,
2017 "operands": (1, 2),
2018 "rank": (2, 2),
2019 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
2020 "qgen": TosaQuantGen.qgConv,
2021 "types": TYPE_CONV2D,
2022 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002023 "matmul": {
2024 "op": Op.MATMUL,
2025 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002026 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08002027 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
2028 "qgen": TosaQuantGen.qgMatmul,
2029 "types": TYPE_NARROW_INT_FP,
2030 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002031 "max_pool2d": {
2032 "op": Op.MAX_POOL2D,
2033 "operands": (1, 0),
2034 "rank": (4, 4),
2035 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
2036 "types": TYPE_NARROW_INT_FP,
2037 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002038 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002039 "transpose_conv2d_TEMPLATE": {
2040 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002041 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002042 "rank": (4, 4),
2043 "build_fcn": (
2044 build_transpose_conv2d,
2045 TosaTensorGen.tgTransposeConv2D,
2046 TosaArgGen.agTransposeConv2D,
2047 ),
2048 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07002049 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002050 "template": True,
2051 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002052 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002053 "clamp": {
2054 "op": Op.CLAMP,
2055 "operands": (1, 0),
2056 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
2057 "types": TYPE_NARROW_INT_FP,
2058 },
2059 "relun": {
2060 "op": Op.RELUN,
2061 "operands": (1, 0),
2062 "build_fcn": (build_relun, TosaTensorGen.tgBasic, None),
2063 "types": TYPE_FI32,
2064 },
2065 "sigmoid": {
2066 "op": Op.SIGMOID,
2067 "operands": (1, 0),
2068 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
2069 "types": TYPE_FP,
2070 },
2071 "tanh": {
2072 "op": Op.TANH,
2073 "operands": (1, 0),
2074 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
2075 "types": TYPE_FP,
2076 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002077 # Elementwise Binary Operators
2078 "add": {
2079 "op": Op.ADD,
2080 "operands": (2, 0),
2081 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2082 "types": TYPE_FI32,
2083 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002084 "arithmetic_right_shift": {
2085 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2086 "operands": (2, 0),
2087 "build_fcn": (
2088 build_arithmetic_right_shift,
2089 TosaTensorGen.tgBroadcastFuzz,
2090 TosaArgGen.agArithmeticRightShift,
2091 ),
2092 "types": TYPE_INT,
2093 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002094 "bitwise_and": {
2095 "op": Op.BITWISE_AND,
2096 "operands": (2, 0),
2097 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2098 "types": TYPE_INT,
2099 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002100 "bitwise_or": {
2101 "op": Op.BITWISE_OR,
2102 "operands": (2, 0),
2103 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2104 "types": TYPE_INT,
2105 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002106 "bitwise_xor": {
2107 "op": Op.BITWISE_XOR,
2108 "operands": (2, 0),
2109 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2110 "types": TYPE_INT,
2111 },
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002112 "div": {
2113 "op": Op.DIV,
2114 "operands": (2, 0),
2115 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2116 "types": [DType.INT32],
2117 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002118 "logical_and": {
2119 "op": Op.LOGICAL_AND,
2120 "operands": (2, 0),
2121 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2122 "types": TYPE_BOOL,
2123 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002124 "logical_left_shift": {
2125 "op": Op.LOGICAL_LEFT_SHIFT,
2126 "operands": (2, 0),
2127 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2128 "types": TYPE_INT,
2129 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002130 "logical_right_shift": {
2131 "op": Op.LOGICAL_RIGHT_SHIFT,
2132 "operands": (2, 0),
2133 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2134 "types": TYPE_INT,
2135 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002136 "logical_or": {
2137 "op": Op.LOGICAL_OR,
2138 "operands": (2, 0),
2139 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2140 "types": TYPE_BOOL,
2141 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002142 "logical_xor": {
2143 "op": Op.LOGICAL_XOR,
2144 "operands": (2, 0),
2145 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2146 "types": TYPE_BOOL,
2147 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002148 "maximum": {
2149 "op": Op.MAXIMUM,
2150 "operands": (2, 0),
2151 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2152 "types": TYPE_FI32,
2153 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002154 "minimum": {
2155 "op": Op.MINIMUM,
2156 "operands": (2, 0),
2157 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2158 "types": TYPE_FI32,
2159 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002160 "mul": {
2161 "op": Op.MUL,
2162 "operands": (2, 0),
2163 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
2164 "types": TYPE_INT_FP,
2165 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002166 "pow": {
2167 "op": Op.POW,
2168 "operands": (2, 0),
2169 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
2170 "types": TYPE_FP,
2171 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002172 "sub": {
2173 "op": Op.SUB,
2174 "operands": (2, 0),
2175 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2176 "types": TYPE_FI32,
2177 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002178 "table": {
2179 "op": Op.TABLE,
2180 # Use the automatic generation functions to create the input array
2181 # but create the table tensor in the build function, as it may be
2182 # a different type from the input
2183 "operands": (1, 0),
2184 "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002185 "types": [DType.INT8, DType.INT16],
Jared Smolens573ecd42021-03-04 15:24:10 -08002186 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002187 # Elementwise Unary operators
2188 "abs": {
2189 "op": Op.ABS,
2190 "operands": (1, 0),
2191 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2192 "types": TYPE_FI32,
2193 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002194 "bitwise_not": {
2195 "op": Op.BITWISE_NOT,
2196 "operands": (1, 0),
2197 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2198 "types": TYPE_INT,
2199 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002200 "ceil": {
2201 "op": Op.CEIL,
2202 "operands": (1, 0),
2203 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2204 "types": TYPE_FP,
2205 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002206 "clz": {
2207 "op": Op.CLZ,
2208 "operands": (1, 0),
2209 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2210 "types": [DType.INT32],
2211 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002212 "exp": {
2213 "op": Op.EXP,
2214 "operands": (1, 0),
2215 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2216 "types": TYPE_FP,
2217 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002218 "floor": {
2219 "op": Op.FLOOR,
2220 "operands": (1, 0),
2221 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2222 "types": TYPE_FP,
2223 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002224 "log": {
2225 "op": Op.LOG,
2226 "operands": (1, 0),
2227 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2228 "types": TYPE_FP,
2229 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002230 "logical_not": {
2231 "op": Op.LOGICAL_NOT,
2232 "operands": (1, 0),
2233 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2234 "types": TYPE_BOOL,
2235 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002236 "negate": {
2237 "op": Op.NEGATE,
2238 "operands": (1, 0),
2239 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2240 "qgen": TosaQuantGen.qgUnary,
2241 "types": TYPE_INT_FP,
2242 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002243 "reciprocal": {
2244 "op": Op.RECIPROCAL,
2245 "operands": (1, 0),
2246 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2247 "types": TYPE_FP,
2248 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002249 "rsqrt": {
2250 "op": Op.RSQRT,
2251 "operands": (1, 0),
2252 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2253 "types": TYPE_FP,
2254 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002255 # Elementwise Ternary operators
2256 "select": {
2257 "op": Op.SELECT,
2258 "operands": (3, 0),
2259 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
2260 "types": TYPE_FIB,
2261 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002262 # Comparison operators
2263 "equal": {
2264 "op": Op.EQUAL,
2265 "operands": (2, 0),
2266 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2267 "types": TYPE_FI32,
2268 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002269 "greater_equal": {
2270 "op": Op.GREATER_EQUAL,
2271 "operands": (2, 0),
2272 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2273 "types": TYPE_FI32,
2274 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002275 "greater": {
2276 "op": Op.GREATER,
2277 "operands": (2, 0),
2278 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2279 "types": TYPE_FI32,
2280 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002281 # Reduction operators
2282 "reduce_all": {
2283 "op": Op.REDUCE_ALL,
2284 "operands": (1, 0),
2285 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2286 "types": TYPE_BOOL,
2287 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002288 "reduce_any": {
2289 "op": Op.REDUCE_ANY,
2290 "operands": (1, 0),
2291 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2292 "types": TYPE_BOOL,
2293 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002294 "reduce_max": {
2295 "op": Op.REDUCE_MAX,
2296 "operands": (1, 0),
2297 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2298 "types": TYPE_INT_FP,
2299 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002300 "reduce_min": {
2301 "op": Op.REDUCE_MAX,
2302 "operands": (1, 0),
2303 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2304 "types": TYPE_INT_FP,
2305 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002306 "reduce_product": {
2307 "op": Op.REDUCE_PRODUCT,
2308 "operands": (1, 0),
2309 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2310 "types": TYPE_FP,
2311 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002312 "reduce_sum": {
2313 "op": Op.REDUCE_SUM,
2314 "operands": (1, 0),
2315 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2316 "types": TYPE_FI32,
2317 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002318 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002319 "concat": {
2320 "op": Op.CONCAT,
2321 "operands": (2, 0),
2322 "build_fcn": (build_concat, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2323 "types": TYPE_FIB,
2324 },
2325 "pad": {
2326 "op": Op.PAD,
2327 "operands": (1, 0),
2328 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
2329 "qgen": TosaQuantGen.qgPad,
2330 "types": TYPE_FIB,
2331 },
2332 "reshape": {
2333 "op": Op.RESHAPE,
2334 "operands": (1, 0),
2335 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
2336 "types": TYPE_FIB,
2337 },
2338 "reverse": {
2339 "op": Op.REVERSE,
2340 "operands": (1, 0),
2341 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2342 "types": TYPE_FIB,
2343 },
2344 "slice": {
2345 "op": Op.SLICE,
2346 "operands": (1, 0),
2347 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
2348 "types": TYPE_FIB,
2349 },
2350 "tile": {
2351 "op": Op.TILE,
2352 "operands": (1, 0),
2353 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
2354 "types": TYPE_FIB,
2355 },
2356 "transpose": {
2357 "op": Op.TRANSPOSE,
2358 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01002359 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002360 "build_fcn": (
2361 build_transpose,
2362 TosaTensorGen.tgBasic,
2363 TosaArgGen.agTranspose,
2364 ),
2365 "types": TYPE_FIB,
2366 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002367 # Data nodes
2368 "const": {
2369 "op": Op.CONST,
2370 "operands": (1, 0),
2371 "build_fcn": (build_placeholder, TosaTensorGen.tgBasic, None),
2372 "types": TYPE_FIB,
2373 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002374 "identity": {
2375 "op": Op.IDENTITY,
2376 "operands": (1, 0),
2377 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2378 "types": TYPE_FIB,
2379 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002380 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08002381 "gather": {
2382 "op": Op.GATHER,
2383 # Only specify 'values' tensor here. 'indices' is generated in op building stage
2384 "operands": (1, 0),
2385 "rank": (3, 3),
2386 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
2387 "types": TYPE_INT_FP,
2388 },
2389 "scatter": {
2390 "op": Op.SCATTER,
2391 # Only specify 'values_in' tensor here.
2392 #'indices' and 'input' are generated in op building stage
2393 "operands": (2, 0),
2394 "rank": (3, 3),
2395 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
2396 "types": TYPE_INT_FP,
2397 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002398 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08002399 "resize": {
2400 "op": Op.RESIZE,
2401 "operands": (1, 0),
2402 "rank": (4, 4),
2403 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
2404 "types": [DType.INT8, DType.INT16, DType.FLOAT],
2405 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002406 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08002407 "cast": {
2408 "op": Op.CAST,
2409 "operands": (1, 0),
2410 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
2411 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
2412 },
2413 "rescale": {
2414 "op": Op.RESCALE,
2415 "operands": (1, 0),
2416 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
2417 "types": [DType.INT8, DType.INT16, DType.INT32, DType.INT48],
2418 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002419 # Custom
2420 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08002421 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07002422 # Two varients of cond_if, one that generates one of two constant tensors (no
2423 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
2424 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002425 "cond_if_const": {
2426 "op": Op.COND_IF,
2427 "operands": (0, 2),
2428 "build_fcn": (
2429 build_cond_if_const,
2430 TosaTensorGen.tgBasic,
2431 TosaArgGen.agCondIf,
2432 ),
2433 "types": [DType.BOOL],
2434 },
2435 "cond_if_binary": {
2436 "op": Op.COND_IF,
2437 "operands": (2, 0),
2438 "build_fcn": (
2439 build_cond_if_binary,
2440 TosaTensorGen.tgBasic,
2441 TosaArgGen.agCondIf,
2442 ),
2443 "types": TYPE_FI32,
2444 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002445 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002446 "while_loop": {
2447 "op": Op.WHILE_LOOP,
2448 "operands": (0, 1),
2449 "build_fcn": (
2450 build_while_loop,
2451 TosaTensorGen.tgBasic,
2452 TosaArgGen.agWhileLoop,
2453 ),
2454 "types": [DType.INT32],
2455 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002456 }
2457
Kevin Cheng550ccc52021-03-03 11:21:43 -08002458
Eric Kunzee5e26762020-10-13 16:11:07 -07002459class OutputShaper:
2460 # Methods in this class compute the expected output shape and datatype
2461 # for common classes of operations
def __init__(self):
    """Construct an OutputShaper.

    The shaper holds no per-instance state, so there is nothing to set up.
    """
2464
2465 # These methods return arguments that can be used for
2466 # creating a new output tensor
2467 @staticmethod
def binaryBroadcastOp(ser, a, b):
    """Add the output of an elementwise binary op that broadcasts size-1 axes.

    Both operands must have the same rank and dtype. On each axis where ``a``
    has size 1, the output takes ``b``'s size; otherwise ``a``'s size is kept.
    ``b`` is not checked for compatibility on the non-broadcast axes.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        a: first operand; needs an indexable ``.shape`` and a ``.dtype``.
        b: second operand; same rank and dtype as ``a``.

    Returns:
        The result of ``ser.addOutput`` for the broadcast shape and ``a.dtype``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = [
        b.shape[axis] if dim == 1 else dim for axis, dim in enumerate(a.shape)
    ]
    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002480
2481 @staticmethod
def binaryNonBroadcastOp(ser, a, b):
    """Add the output of an elementwise binary op that does not broadcast.

    The operands must agree in rank, dtype and every dimension. The output
    has that common shape (as a fresh list) and dtype.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        a: first operand with ``.shape`` and ``.dtype``.
        b: second operand; must match ``a`` exactly.

    Returns:
        The result of ``ser.addOutput``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype
    # Every axis must match; checked in order, stopping at the first mismatch.
    assert all(dim == b.shape[axis] for axis, dim in enumerate(a.shape))

    return ser.addOutput(list(a.shape), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002492
2493 @staticmethod
def unaryOp(ser, a):
    """Add an output with exactly the input's shape and dtype.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        a: input tensor. Its ``.shape`` object is passed through uncopied.

    Returns:
        The result of ``ser.addOutput``.
    """
    shape, dtype = a.shape, a.dtype
    return ser.addOutput(shape, dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002496
2497 @staticmethod
def selectOp(ser, cond, a, b):
    """Add the output of SELECT(cond, a, b).

    All three inputs must share a rank, and ``a``/``b`` must share a dtype.
    Each output axis is the largest of the three input sizes on that axis.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        cond: condition tensor (only ``.shape`` is read).
        a: value tensor used when the condition holds.
        b: value tensor used otherwise.

    Returns:
        The result of ``ser.addOutput`` with ``a.dtype``.
    """
    assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
    assert a.dtype == b.dtype

    out_shape = [
        max(cond.shape[axis], dim, b.shape[axis]) for axis, dim in enumerate(a.shape)
    ]
    return ser.addOutput(out_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002507
2508 @staticmethod
def binaryComparisonOp(ser, a, b):
    """Add the output of an elementwise comparison (e.g. EQUAL, GREATER).

    Shapes broadcast as in ``binaryBroadcastOp``: a size-1 axis in ``a``
    takes ``b``'s size, otherwise ``a``'s size is kept. The output dtype is
    always ``DType.BOOL``, whatever the operand dtype.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        a: first operand with ``.shape`` and ``.dtype``.
        b: second operand; same rank and dtype as ``a``.

    Returns:
        The result of ``ser.addOutput``.
    """
    assert len(a.shape) == len(b.shape)
    assert a.dtype == b.dtype

    out_shape = [
        b.shape[axis] if dim == 1 else dim for axis, dim in enumerate(a.shape)
    ]
    # Comparison results are booleans regardless of the input dtype.
    return ser.addOutput(out_shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07002523
2524 @staticmethod
def reduceOp(ser, a, axis):
    """Add the output of a reduction along ``axis``, keeping it with size 1.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        a: input tensor; ``a.shape`` must support ``.copy()`` and is not modified.
        axis: axis being reduced (any index valid for a list, including negative).

    Returns:
        The result of ``ser.addOutput`` with ``a.dtype``.
    """
    reduced = a.shape.copy()
    reduced[axis] = 1
    return ser.addOutput(reduced, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002532
2533 @staticmethod
def argmaxOp(ser, a, axis):
    """Add the output of ARGMAX: ``axis`` is removed and indices are INT32.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        a: input tensor; ``a.shape`` must support ``.copy()`` and is not modified.
        axis: axis to drop from the output shape.

    Returns:
        The result of ``ser.addOutput`` with ``DType.INT32``.
    """
    remaining = a.shape.copy()
    del remaining[axis]
    return ser.addOutput(remaining, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07002538
2539 @staticmethod
def conv2dOp(ser, ifm, filter, strides, padding, dilations):
    """Add the output tensor of a 2D convolution.

    Layouts: IFM is NHWC, filter is OHWI, output is NHWC.

    Args:
        ser: serializer; ``addOutput(shape, dtype)`` and
            ``setExpectedFailure(flag, msg)`` are used.
        ifm: input feature map with ``.shape`` and ``.dtype``.
        filter: weights; ``filter.shape`` is (O, H, W, I).
        strides: (stride_y, stride_x).
        padding: four values (top, bottom, left, right), or two values
            (pad_y, pad_x) that are applied to both sides.
        dilations: (dilation_y, dilation_x).

    Returns:
        The result of ``ser.addOutput``. If the computed height or width is
        not positive, both become 0 and the test is marked as an expected
        failure.

    Raises:
        Exception: if ``ifm.dtype`` is not INT8, INT16 or FLOAT.
    """
    if len(padding) == 2:
        # transpose_conv2d passes only H,W padding; expand it to T,B,L,R.
        padding = [padding[0], padding[0], padding[1], padding[1]]

    kernel_h, kernel_w = filter.shape[1], filter.shape[2]
    # Extent covered by a dilated kernel: k + (k - 1) * (dilation - 1).
    span_h = kernel_h + (kernel_h - 1) * (dilations[0] - 1)
    span_w = kernel_w + (kernel_w - 1) * (dilations[1] - 1)

    h = (ifm.shape[1] + padding[0] + padding[1] - span_h) // strides[0] + 1
    w = (ifm.shape[2] + padding[2] + padding[3] - span_w) // strides[1] + 1

    if h <= 0 or w <= 0:
        # Parameters cannot produce a valid output; flag the test.
        h = w = 0
        ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")

    ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

    # Output dtype is derived from the input dtype.
    out_dtype = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }.get(ifm.dtype)
    if out_dtype is None:
        raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002585
2586 @staticmethod
def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
    """Add the output tensor of a depthwise 2D convolution.

    Layouts: IFM is NHWC, filter is HWCM, output is NHW with C*M channels.

    Args:
        ser: serializer; ``addOutput(shape, dtype)`` and
            ``setExpectedFailure(flag, msg)`` are used.
        ifm: input feature map with ``.shape`` and ``.dtype``.
        filter: weights; ``filter.shape`` is (H, W, C, M).
        strides: (stride_y, stride_x).
        padding: four values (top, bottom, left, right).
        dilations: (dilation_y, dilation_x).

    Returns:
        The result of ``ser.addOutput``. If the computed height or width is
        not positive, both become 0 and the test is marked as an expected
        failure.

    Raises:
        Exception: if ``ifm.dtype`` is not INT8, INT16 or FLOAT.
    """
    kernel_h, kernel_w = filter.shape[0], filter.shape[1]
    # Extent covered by a dilated kernel: k + (k - 1) * (dilation - 1).
    span_h = kernel_h + (kernel_h - 1) * (dilations[0] - 1)
    span_w = kernel_w + (kernel_w - 1) * (dilations[1] - 1)

    h = (ifm.shape[1] + padding[0] + padding[1] - span_h) // strides[0] + 1
    w = (ifm.shape[2] + padding[2] + padding[3] - span_w) // strides[1] + 1

    if h <= 0 or w <= 0:
        # Parameters cannot produce a valid output; flag the test.
        h = w = 0
        ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")

    # Output channels: input channels times the channel multiplier.
    ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

    out_dtype = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }.get(ifm.dtype)
    if out_dtype is None:
        raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

    return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002625
2626 @staticmethod
def pool2dOp(ser, ifm, kernel, stride, pad):
    """Add the output tensor of a 2D pooling op (NHWC in, NHWC out).

    Args:
        ser: serializer; ``addOutput(shape, dtype)`` and
            ``setExpectedFailure(flag, msg)`` are used.
        ifm: input tensor with ``.shape`` (N, H, W, C) and ``.dtype``.
        kernel: (kernel_h, kernel_w).
        stride: (stride_y, stride_x).
        pad: (top, bottom, left, right).

    Returns:
        The result of ``ser.addOutput`` with ``ifm.dtype``. If the computed
        height or width is not positive, both become 0 and the test is
        marked as an expected failure.
    """
    in_h, in_w = ifm.shape[1], ifm.shape[2]
    out_h = (in_h + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
    out_w = (in_w + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]

    if out_h <= 0 or out_w <= 0:
        # Parameters cannot produce a valid output; flag the test.
        out_h = out_w = 0
        ser.setExpectedFailure(True, "Invalid combination of pooling parameters")

    return ser.addOutput([ifm.shape[0], out_h, out_w, ifm.shape[3]], ifm.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002640
2641 @staticmethod
def fullyConnectedOp(ser, input, filter):
    """Add the output tensor of FULLY_CONNECTED.

    Shapes: input is (N, IC), filter is (OC, IC), output is (N, OC). Only
    ``input.shape[0]`` and ``filter.shape[0]`` are read; IC is not checked.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        input: input tensor with ``.shape`` and ``.dtype``.
        filter: weight tensor with ``.shape``.

    Returns:
        The result of ``ser.addOutput``.

    Raises:
        Exception: if ``input.dtype`` is not INT8, INT16 or FLOAT.
    """
    batch, out_channels = input.shape[0], filter.shape[0]
    output_shape = [batch, out_channels]

    out_dtype = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }.get(input.dtype)
    if out_dtype is None:
        raise Exception("Unsupported input dtype: {}".format(input.dtype))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002659
2660 @staticmethod
def matmulOp(ser, a, b):
    """Add the output tensor of a batched MATMUL.

    Shapes: a is (N, H, C), b is (N, C, W), output is (N, H, W). Only the
    N and H axes of ``a`` and the W axis of ``b`` are read; the batch and
    shared C axes are not cross-checked here.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        a: left operand with ``.shape`` and ``.dtype``.
        b: right operand with ``.shape``.

    Returns:
        The result of ``ser.addOutput``.

    Raises:
        Exception: if ``a.dtype`` is not INT8, INT16 or FLOAT.
    """
    output_shape = [a.shape[0], a.shape[1], b.shape[2]]

    if a.dtype == DType.INT8:
        out_dtype = DType.INT32
    elif a.dtype == DType.INT16:
        out_dtype = DType.INT48
    elif a.dtype == DType.FLOAT:
        out_dtype = DType.FLOAT
    else:
        raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))

    return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002678
2679 @staticmethod
def concatOp(ser, a, b, axis):
    """Add the output of CONCAT of ``a`` and ``b`` along ``axis``.

    The output equals ``a``'s shape except that ``axis`` is the sum of both
    inputs' sizes on that axis. Other axes are not cross-checked.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        a: first input; ``a.shape`` must support ``.copy()`` and is not modified.
        b: second input with ``.shape``.
        axis: concatenation axis.

    Returns:
        The result of ``ser.addOutput`` with ``a.dtype``.
    """
    joined = a.shape.copy()
    joined[axis] += b.shape[axis]
    return ser.addOutput(joined, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002686
2687 @staticmethod
def padOp(ser, a, padding):
    """Add the output of PAD: each axis grows by its before/after padding.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        a: input tensor; ``a.shape`` must support ``.copy()`` and is not modified.
        padding: per-axis pairs; ``padding[i][0]`` and ``padding[i][1]`` are
            added to axis ``i``.

    Returns:
        The result of ``ser.addOutput`` with ``a.dtype``.
    """
    padded = a.shape.copy()
    for axis, dim in enumerate(a.shape):
        before, after = padding[axis][0], padding[axis][1]
        padded[axis] = before + after + dim
    return ser.addOutput(padded, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002696
2697 @staticmethod
def reshapeOp(ser, a, shape):
    """Add the output of RESHAPE, resolving ``-1`` placeholders.

    Any ``-1`` entry in ``shape`` is replaced by the input element count
    divided (floor) by the product of the other requested dimensions.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        a: input tensor with ``.shape`` and ``.dtype``.
        shape: requested output shape; must support ``.copy()`` and is not
            modified.

    Returns:
        The result of ``ser.addOutput`` with ``a.dtype``.
    """
    requested = shape.copy()

    input_count = 1
    for dim in a.shape:
        input_count *= dim

    # Product of the explicitly given dimensions (placeholders skipped).
    known_count = 1
    for dim in requested:
        if dim != -1:
            known_count *= dim

    # Division happens only when a placeholder is actually present.
    resolved = [input_count // known_count if dim == -1 else dim for dim in requested]
    return ser.addOutput(resolved, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002717
2718 @staticmethod
def sliceOp(ser, a, begin, size):
    """Add the output of SLICE; its shape is a copy of ``size``.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        a: input tensor; only ``.dtype`` is read.
        begin: slice start; accepted for interface compatibility, unused here.
        size: slice extent; must support ``.copy()``.

    Returns:
        The result of ``ser.addOutput`` with ``a.dtype``.
    """
    return ser.addOutput(size.copy(), a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002723
2724 @staticmethod
def tileOp(ser, a, multiples):
    """Add the output of TILE: each axis is multiplied by its repeat count.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        a: input tensor; ``a.shape`` must support ``.copy()`` and is not modified.
        multiples: repeat count per axis; must have the same length as ``a.shape``.

    Returns:
        The result of ``ser.addOutput`` with ``a.dtype``.
    """
    tiled = a.shape.copy()
    assert len(multiples) == len(tiled)

    for axis, dim in enumerate(a.shape):
        tiled[axis] = dim * multiples[axis]

    return ser.addOutput(tiled, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002734
2735 @staticmethod
def transposeOp(ser, a, perms):
    """Add the output of TRANSPOSE.

    Output axis ``i`` takes its size from input axis ``perms[i]``.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        a: input tensor; ``a.shape`` must support ``.copy()`` and is not modified.
        perms: permutation with one entry per input axis.

    Returns:
        The result of ``ser.addOutput`` with ``a.dtype``.
    """
    src_shape = a.shape
    out_shape = src_shape.copy()
    assert len(perms) == len(out_shape)

    out_shape = [src_shape[perms[axis]] for axis in range(len(src_shape))]
    return ser.addOutput(out_shape, a.dtype)
2745 @staticmethod
def gatherOp(ser, values, indices):
    """Add the output of GATHER.

    ``values`` is rank 3 (N, K, C) and ``indices`` is rank 2 (N, W), with a
    matching N. The output is (N, W, C) with ``values.dtype``.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        values: source tensor with ``.shape`` and ``.dtype``.
        indices: index tensor with ``.shape``.

    Returns:
        The result of ``ser.addOutput``.
    """
    assert len(values.shape) == 3
    assert len(indices.shape) == 2
    assert values.shape[0] == indices.shape[0]

    batch, channels = values.shape[0], values.shape[2]
    gathered = indices.shape[1]
    return ser.addOutput([batch, gathered, channels], values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08002754
2755 @staticmethod
def scatterOp(ser, values_in, indices, input):
    """Add the output of SCATTER; its shape is exactly ``values_in.shape``.

    Checked: ``values_in`` and ``input`` are rank 3 and ``indices`` is rank 2.
    N matches between ``values_in`` and ``indices``, W matches between
    ``input`` and ``indices``, and C matches between ``values_in`` and ``input``.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        values_in: tensor being scattered into; its ``.shape`` object is
            passed through uncopied.
        indices: index tensor with ``.shape``.
        input: tensor of values to scatter, with ``.shape``.

    Returns:
        The result of ``ser.addOutput`` with ``values_in.dtype``.
    """
    assert len(values_in.shape) == 3
    assert len(indices.shape) == 2
    assert len(input.shape) == 3
    assert values_in.shape[0] == indices.shape[0]  # N
    assert input.shape[1] == indices.shape[1]  # W
    assert values_in.shape[2] == input.shape[2]  # C

    return ser.addOutput(values_in.shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002767
2768 @staticmethod
def tableOp(ser, input, table_dtype):
    """Add the output of TABLE: same shape as ``input``, dtype set by the table.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        input: input tensor; its ``.shape`` object is passed through uncopied.
        table_dtype: ``DType.INT16`` or ``DType.INT8``.

    Returns:
        The result of ``ser.addOutput``: INT32 output for an INT16 table,
        INT8 output for an INT8 table.
    """
    assert table_dtype in (DType.INT16, DType.INT8)

    if table_dtype == DType.INT16:
        output_dtype = DType.INT32
    else:
        output_dtype = DType.INT8
    return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002774
2775 @staticmethod
def resizeOp(
    ser,
    input,
    mode,
    stride,
    offset,
    shift,
    stride_fp,
    offset_fp,
    output_dims,
    input_dtype,
    output_dtype,
):
    """Add the output tensor of RESIZE and flag invalid parameter combinations.

    The output keeps axis 0 and axis 3 of ``input`` (N and C, assuming the
    rank-4 NHWC input this generator builds). Its height and width come from
    ``output_dims``. Invalid parameters are reported with
    ``ser.setExpectedFailure(True, msg)``, never raised.

    Args:
        ser: serializer; ``addOutput(shape, dtype)`` and
            ``setExpectedFailure(flag, msg)`` are used.
        input: input tensor; only ``.shape`` is read.
        mode: ``ResizeMode.BILINEAR`` or ``ResizeMode.NEAREST``; any other
            value is flagged as an invalid mode.
        stride: integer (y, x) stride, checked when the input is not FLOAT.
        offset, shift, offset_fp: accepted for interface compatibility and
            not used here.
        stride_fp: floating-point (y, x) stride, checked for FLOAT inputs.
        output_dims: (height, width) of the output.
        input_dtype: input dtype; must be INT8, INT16 or FLOAT.
        output_dtype: output dtype; must be the one ``mode`` requires for
            ``input_dtype``.

    Returns:
        The result of ``ser.addOutput`` for the computed shape and
        ``output_dtype``.
    """
    output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]

    # FLOAT resizes use the floating-point stride; integer resizes the integer one.
    active_stride = stride_fp if input_dtype == DType.FLOAT else stride
    if active_stride[0] <= 0 or active_stride[1] <= 0:
        ser.setExpectedFailure(True, "Negative or zero stride")

    # Required output dtype for each supported input dtype, per resize mode.
    if mode == ResizeMode.BILINEAR:
        required_out = {
            DType.INT8: DType.INT32,
            DType.INT16: DType.INT48,
            DType.FLOAT: DType.FLOAT,
        }
    elif mode == ResizeMode.NEAREST:
        required_out = {
            DType.INT8: DType.INT8,
            DType.INT16: DType.INT16,
            DType.FLOAT: DType.FLOAT,
        }
    else:
        required_out = None
        ser.setExpectedFailure(True, "Invalid resize mode")

    if required_out is not None:
        if input_dtype not in required_out:
            ser.setExpectedFailure(True, "Invalid input data type")
        elif output_dtype != required_out[input_dtype]:
            ser.setExpectedFailure(True, "Invalid output data type")

    return ser.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002829
2830 @staticmethod
def typeConversionOp(ser, val, out_dtype):
    """Add the output of a type conversion: same shape, new dtype.

    Args:
        ser: serializer; only its ``addOutput(shape, dtype)`` is used.
        val: input tensor; its ``.shape`` object is passed through uncopied.
        out_dtype: dtype of the converted output.

    Returns:
        The result of ``ser.addOutput``.
    """
    same_shape = val.shape
    return ser.addOutput(same_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002833
2834 @staticmethod
def transposeConv2DOp(ser, ifm, output_shape):
    """Add the output tensor of TRANSPOSE_CONV2D with a caller-given shape.

    Args:
        ser: serializer; ``addOutput(shape, dtype)`` and
            ``setExpectedFailure(flag, msg)`` are used.
        ifm: input feature map; only ``.dtype`` is read.
        output_shape: full output shape; axes 1 and 2 are checked for
            positivity. It is passed to ``addOutput`` uncopied.

    Returns:
        The result of ``ser.addOutput``. A non-positive height or width marks
        the test as an expected failure.

    Raises:
        Exception: if ``ifm.dtype`` is not INT8, INT16 or FLOAT.
    """
    out_dtype = {
        DType.INT8: DType.INT32,
        DType.INT16: DType.INT48,
        DType.FLOAT: DType.FLOAT,
    }.get(ifm.dtype)
    if out_dtype is None:
        raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

    if output_shape[1] <= 0 or output_shape[2] <= 0:
        ser.setExpectedFailure(True, "Negative output shape")

    return ser.addOutput(output_shape, out_dtype)