#!/usr/bin/env python3

# Copyright (c) 2020-2021, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import argparse
import sys
import re
import os
import subprocess
import shlex
import json
import glob
import math
import queue
import threading
import traceback
import itertools

from enum import IntEnum, Enum, unique

# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(
    os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
)
import tosa_serializer as ts
from tosa_serializer import *
import tosa

# Convenience variables to the flatc-generated types that should be enums, but aren't
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()


class TosaQuantGen:
    """QuantizedInfo random generator helper functions.  Specify with 'qgen': in the operator definition"""

    def __init__(self):
        pass

    @staticmethod
    def needsQinfo(op, dtype):
        if dtype == DType.INT8 or dtype == DType.INT16:
            return True
        return False

    @staticmethod
    def qgUnary(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        if TosaQuantGen.needsQinfo(op, dtype):
            qinfo.UnaryQuantInfo(testGen.randInt(), testGen.randInt())
        else:
            qinfo.UnaryQuantInfo(0, 0)
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        if TosaQuantGen.needsQinfo(op, dtype):
            qinfo.ConvQuantInfo(testGen.randInt(), testGen.randInt())
        else:
            qinfo.ConvQuantInfo(0, 0)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        if TosaQuantGen.needsQinfo(op, dtype):
            qinfo.MatMulQuantInfo(testGen.randInt(), testGen.randInt())
        else:
            qinfo.MatMulQuantInfo(0, 0)
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        if TosaQuantGen.needsQinfo(op, dtype):
            qinfo.PadQuantInfo(testGen.randInt())
        else:
            qinfo.PadQuantInfo(0)
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(scaleFp, scaleBits, m, multiplier, shift))

        assert multiplier <= (1 << scaleBits)
        assert shift >= 0 and shift <= 63

        return multiplier, shift
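
    # A worked example (illustrative, not from the original source):
    #
    #   multiplier, shift = TosaQuantGen.computeMultiplierAndShift(0.75, True)
    #   # math.frexp(0.75) == (0.75, 0), so multiplier == round(0.75 * 2**31)
    #   # == 1610612736 and shift == -0 + 31 == 31; the fixed-point product
    #   # multiplier / 2**shift recovers 0.75 exactly.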


class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.  The actual random data is generated separately for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank):
        pl, const = opName["operands"]

        assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list
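
    # Illustrative example (values invented for clarity): with rank == 3, a
    # base shape of [4, 5, 6], two operands (pl + const == 2), bcast_idx == 1
    # and fuzz_idx == 2, the returned shape list is [[4, 5, 6], [4, 5, 1]]
    # -- the second operand broadcasts along its last dimension.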

    @staticmethod
    def tgConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]
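
    # Illustrative example (values invented for clarity): for an IFM shape of
    # [1, 16, 16, 8], an operator 'filter' entry of [3, 3], and a randomly
    # chosen OFM depth of 32, this returns
    #   [[1, 16, 16, 8], [32, 3, 3, 8], [32]]
    # i.e. NHWC input, OHWI weights, and a per-output-channel bias.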

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        filter_oc = testGen.makeShape(1)[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        b_oc = testGen.makeShape(1)[0]
        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])

        return [a_shape, b_shape]
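
    # Illustrative example (values invented for clarity): if a_shape comes out
    # as [2, 3, 4] and the random b_oc is 5, then b_shape is [2, 4, 5], so the
    # batched matmul of (2, 3, 4) x (2, 4, 5) produces a (2, 3, 5) output.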


class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for operators that take
    attributes or other parameters.  The return value is a list of (descriptive_name, [arglist])
    tuples where the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function."""

    def __init__(self):
        pass

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype):
        """Build the axis argument for operators that take a single axis"""
        axes = []

        shape = shapeList[0]

        for a in range(0, len(shape)):
            axes.append(("axis{}".format(a), [a]))
        return axes
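
    # Illustrative example (not from the original source): for a rank-3 input
    # shape this returns
    #   [("axis0", [0]), ("axis1", [1]), ("axis2", [2])]
    # so one test is generated per axis, with the name suffix identifying it.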

    @staticmethod
    def agConv2D(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        maxStride = testGen.args.max_conv_stride
        maxPadding = testGen.args.max_conv_padding + 1
        maxDilation = testGen.args.max_conv_dilation

        # Strides, padding, dilations
        for stride in range(0, maxStride ** 2):
            for padding in range(0, (maxPadding) ** 4):
                for dilation in range(0, maxDilation ** 2):

                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    p = [
                        (padding // (maxPadding * 4)) % maxPadding,
                        (padding // (maxPadding * 2)) % maxPadding,
                        (padding // (maxPadding * 1)) % maxPadding,
                        padding % maxPadding,
                    ]
                    d = [dilation // maxDilation + 1, dilation % maxDilation + 1]

                    # 4 padding parameters for regular conv2d
                    arg_list.append(
                        (
                            "st{}{}_pad{}{}{}{}_dilat{}{}".format(
                                s[0], s[1], p[0], p[1], p[2], p[3], d[0], d[1]
                            ),
                            [s, p, d],
                        )
                    )
        return arg_list
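
    # Illustrative decode (not from the original source): with maxStride == 2
    # the stride index runs over range(4) and unpacks to
    #   0 -> s = [1, 1],  1 -> s = [1, 2],  2 -> s = [2, 1],  3 -> s = [2, 2]
    # Dilation unpacks the same way; with maxPadding == 2 the padding index
    # likewise unpacks into four 0/1 digits, so every combination of the
    # parameters is enumerated exhaustively.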

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        maxStride = testGen.args.max_conv_stride
        maxPadding = testGen.args.max_conv_padding + 1
        maxDilation = testGen.args.max_conv_dilation

        # Strides, padding, dilations
        for stride in range(0, maxStride ** 2):
            for out_padding in range(0, (maxPadding) ** 2):
                for dilation in range(0, maxDilation ** 2):

                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    p = [
                        (out_padding // (maxPadding * 1)) % maxPadding,
                        out_padding % maxPadding,
                    ]
                    d = [dilation // maxDilation + 1, dilation % maxDilation + 1]

                    oh = (
                        ifm_shape[1]
                        - filter_shape[1]
                        - (filter_shape[1] - 1) * (d[0] - 1)
                        + 2 * p[0]
                    ) // s[0] + 1

                    ow = (
                        ifm_shape[2]
                        - filter_shape[2]
                        - (filter_shape[2] - 1) * (d[1] - 1)
                        + 2 * p[1]
                    ) // s[1] + 1

                    # Output shape
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]

                    arg_list.append(
                        (
                            "st{}{}_outpad{}{}_dilat{}{}_os{}x{}x{}x{}".format(
                                s[0],
                                s[1],
                                p[0],
                                p[1],
                                d[0],
                                d[1],
                                os[0],
                                os[1],
                                os[2],
                                os[3],
                            ),
                            [s, p, d, os],
                        )
                    )

        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype):
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of 0/1 padding on each side of each dimension
        # This process might need some revision for >1 padding, but use rank**2 as a bitmask
        # for now
        for v in range(rank ** 2):

            # Create a flat padding array, two entries per dimension
            paddings = np.zeros((rank * 2), dtype=np.int32)

            # Fill in the 1's
            for r in range(rank * 2):
                if (v >> r) & 1:
                    paddings[r] = 1

            # Reshape back to a 2D array
            paddings = paddings.reshape((rank, 2))

            arg_list.append(("pad{0:b}".format(v), [paddings]))

        return arg_list
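
    # Illustrative example (not from the original source): for rank == 4,
    # v runs over range(16); v == 0b0110 sets flat entries 1 and 2, giving
    # [0, 1, 1, 0, 0, 0, 0, 0], which reshapes to
    # [[0, 1], [1, 0], [0, 0], [0, 0]].  Note that rank**2 == 16 only covers
    # the low 4 of the 8 pad positions, matching the caveat above.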

    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype):
        arg_list = []

        shape = shapeList[0]
        assert len(shape) == 4

        maxStride = testGen.args.max_pooling_stride
        maxKernel = testGen.args.max_pooling_kernel
        maxPadding = testGen.args.max_pooling_padding + 1

        for kernel in range(0, maxKernel ** 2):
            for stride in range(0, maxStride ** 2):
                for padding in range(0, maxPadding ** 4):
                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    k = [(kernel // maxKernel) + 2, (kernel % maxKernel) + 2]
                    p = [
                        (padding // (maxPadding * 4)) % maxPadding,
                        (padding // (maxPadding * 2)) % maxPadding,
                        (padding // (maxPadding * 1)) % maxPadding,
                        padding % maxPadding,
                    ]

                    arg_list.append(
                        (
                            "st{}{}_kern{}{}_pad{}{}{}{}".format(
                                s[0], s[1], k[0], k[1], p[0], p[1], p[2], p[3]
                            ),
                            [k, s, p],
                        )
                    )
        return arg_list

    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype):
        arg_list = []

        # Enumerate the output types here
        if inDtype == DType.INT8:
            dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT16:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT32:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FLOAT:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))

        return arg_list

    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype):
        arg_list = []

        # Enumerate the output types here
        for dtype in [DType.INT8, DType.INT16, DType.INT32]:
            for scale32 in [False, True]:
                for double_round in [False, True]:
                    for per_channel in [False, True]:

                        if inDtype == DType.INT48 and scale32:
                            # Illegal condition.  Must be scale32=False
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[dtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [dtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

    @staticmethod
    def agMul(testGen, opName, shapeList, dtype):
        arg_list = []

        if dtype == DType.INT32:
            for p in range(testGen.args.num_rand_permutations):

                shift = testGen.randInt(0, 32)

                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("perm0_shift0", [0]))

        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype):
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    # Helper function for reshape.  Gets some factors of a larger number.
    @staticmethod
    def getFactors(val, start=1):
        factors = []

        for i in range(start, int(np.sqrt(val)) + 1):
            if (val % i) == 0:
                factors.append(i)

        return factors
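
    # Illustrative example (not from the original source): getFactors(24)
    # scans i in range(1, 5) and returns [1, 2, 3, 4] -- only factors up to
    # sqrt(val), which is all agReshape needs to build candidate shapes.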

    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype):
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast.  Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)

                # Toss in a -1 sometimes
                minusOne = testGen.randInt(0, newRank * 4)
                if minusOne < newRank:
                    newShape[minusOne] = -1

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list
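
    # Illustrative example (values invented for clarity): for an original
    # shape of [2, 3, 4] (24 elements) and newRank == 3, a candidate shape
    # might come out as [4, 2, 3]; the "toss in a -1" step then has a
    # one-in-four chance of rewriting one randomly chosen dimension to -1,
    # e.g. [4, -1, 3], exercising RESHAPE's inferred-dimension path.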

    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]

        # Get all permutations
        permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Get random permutation generator that uses all permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create list of required amount of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list
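
    # Illustrative example (not from the original source): for a rank-3 input
    # there are 3! == 6 permutations of (0, 1, 2); with
    # num_rand_permutations == 4, four of the six are drawn without
    # replacement, e.g. [2, 0, 1], [0, 2, 1], [1, 2, 0], [0, 1, 2].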

    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            begin = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    begin.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    begin.append(0)
                    size.append(1)

            if valid:
                arg_list.append(("perm{}".format(p), [begin, size]))
        return arg_list

    @staticmethod
    def agTile(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random, but small multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                multiples.append(testGen.randInt(1, 4))

            arg_list.append(("perm{}".format(p), [multiples]))

        return arg_list

    @staticmethod
    def agResize(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]

        for m in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations.  Pick legal output types
            if m == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif m == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):

                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    output_dims = [testGen.randInt(1), testGen.randInt(1)]
                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]
                        arg_list.append(
                            (
                                "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                    m,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride_fp[0],
                                    stride_fp[1],
                                    offset_fp[0],
                                    offset_fp[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )
                    else:
                        shift = 11
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))

                        while (
                            stride_y >= 32768
                            or stride_x >= 32768
                            or offset_y >= 32768
                            or offset_x >= 32768
                            or offset_y < -32768
                            or offset_x < -32768
                        ):
                            shift = shift - 1
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                        arg_list.append(
                            (
                                "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                    m,
                                    shift,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride[0],
                                    stride[1],
                                    offset[0],
                                    offset[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )

        return arg_list
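
    # Illustrative fixed-point example (values invented for clarity): a
    # floating-point stride of 2.5 at shift == 11 encodes as
    # round(2.5 * 2048) == 5120; if any encoded stride/offset falls outside
    # the signed 16-bit range, the loop above lowers the shift until
    # everything fits, trading precision for range.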

    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype):
        # While loop: 0 iterations, 1, more than 1
        arg_list = []

        for iter in [0, 1, 4]:
            arg_list.append(("iter{}".format(iter), [iter]))

        return arg_list


class TosaTestGen:
    # Maximum rank of tensor supported by test generator.
    TOSA_TENSOR_MAX_RANK = 6

    def __init__(self, args):
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        self.ser = None
        self.rng = np.random.default_rng(self.random_seed)
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None

    def createSerializer(self, opName, testPath):
        self.testPath = os.path.join(opName, testPath)

        fullPath = os.path.join(self.basePath, self.testPath)
        os.makedirs(fullPath, exist_ok=True)
        self.ser = ts.TosaSerializer(fullPath)

    def getSerializer(self):
        return self.ser

    def serialize(self, testName):
        with open(
            os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
        ) as fd:
            fd.write(self.ser.serialize())

        with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
            fd.write(self.ser.writeJson("{}.tosa".format(testName)))
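
    # Resulting on-disk layout (illustrative, assuming output_dir == "out",
    # opName == "add", and testPath == testName == "add_1x2x3_i32"):
    #   out/add/add_1x2x3_i32/add_1x2x3_i32.tosa   -- serialized graph
    #   out/add/add_1x2x3_i32/desc.json            -- test descriptor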

    def getRandTensor(self, shape, dtype):
        RAND_SHIFT_FACTOR = 0.5
        RAND_SCALE_FACTOR = 4.0

        if dtype == DType.BOOL:
            return np.bool_(self.rng.choice(a=[False, True], size=shape))
        elif dtype == DType.INT4:
            return np.int32(self.rng.integers(low=-7, high=8, size=shape))
        elif dtype == DType.INT8:
            return np.int32(self.rng.integers(low=-127, high=128, size=shape))
        elif dtype == DType.INT16:
            return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
        elif dtype == DType.INT32:
            return np.int32(
                self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
            )
        elif dtype == DType.INT48:
            return np.int64(
                self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
            )
        elif dtype == DType.FLOAT:
            return np.float32(
                self.rng.random(size=shape) - RAND_SHIFT_FACTOR * RAND_SCALE_FACTOR
            )
        else:
            raise Exception("Unrecognized Dtype: {}".format(dtype))

    def buildPlaceholderTensors(self, shape_list, dtype_list):
        placeholders = []

        assert len(shape_list) == len(dtype_list)

        for idx, shape in enumerate(shape_list):
            arr = self.getRandTensor(shape, dtype_list[idx])
            placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))

        return placeholders

    def buildConstTensors(self, shape_list, dtype_list):
        consts = []

        assert len(shape_list) == len(dtype_list)

        for idx, shape in enumerate(shape_list):
            arr = self.getRandTensor(shape, dtype_list[idx])
            consts.append(self.ser.addConst(shape, dtype_list[idx], arr))

        return consts

    def makeShape(self, rank):
        if self.targetted_shape:
            return np.int32(self.targetted_shape)
        return np.int32(
            self.rng.integers(
                low=self.args.tensor_shape_range[0],
                high=self.args.tensor_shape_range[1],
                size=rank,
            )
        )

    def setTargetShape(self, shape):
        self.targetted_shape = shape

    def randInt(self, low=0, high=256):
        return np.int32(self.rng.integers(low=low, high=high, size=1))[0]

    def getRandNumberDType(self, dtype):
        if dtype == DType.FLOAT:
            return self.rng.random()
        elif dtype == DType.BOOL:
            return self.rng.choice([False, True])
        elif dtype == DType.INT4:
            low, high = (-7, 8)
        elif dtype == DType.INT8:
            low, high = (-127, 128)
        elif dtype == DType.INT16:
            low, high = (-32768, 32768)
        elif dtype == DType.INT32:
            low, high = (-(1 << 31), (1 << 31))
        elif dtype == DType.INT48:
            low, high = (-(1 << 47), (1 << 47))
            # Special size
            return np.int64(self.rng.integers(low, high, size=1))[0]
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

        return np.int32(self.rng.integers(low, high, size=1))[0]

    def shapeStr(self, shape):

        sStr = []
        # Convert to strings
        for i in shape:
            sStr.append(str(i))

        return "x".join(sStr)

    def typeStr(self, t):
        if isinstance(t, list):
            assert len(t) >= 2
            return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
        else:
            if t == DType.BOOL:
                return "b"
            elif t == DType.INT4:
                return "i4"
            elif t == DType.INT8:
                return "i8"
            elif t == DType.UINT8:
                return "u8"
            elif t == DType.INT16:
                return "i16"
            elif t == DType.INT32:
                return "i32"
            elif t == DType.INT48:
                return "i48"
            elif t == DType.FLOAT:
                return "float"
            else:
                raise Exception("Unknown dtype, cannot convert to string: {}".format(t))

    def typeWidth(self, t):
        """Get the datatype width for integer types"""
        if t == DType.INT4:
            return 4
        elif t == DType.INT8:
            return 8
        elif t == DType.UINT8:
            return 8
        elif t == DType.INT16:
            return 16
        elif t == DType.INT32:
            return 32
        elif t == DType.INT48:
            return 48
        else:
            raise Exception("Unknown dtype, cannot determine width: {}".format(t))

    # Argument generators
    # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
    # where the string descriptor is used to generate the test name and
    # the build_fcn_arg_list is expanded and passed to the operator test
    # build function

    def build_unary(self, op, a, qinfo=None):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        self.ser.addOperator(op, [a.name], [result_tens.name], None, qinfo)
        return result_tens

    def build_binary_broadcast(self, op, a, b):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
        return result_tens

    def build_binary_nonbroadcast(self, op, a, b):
        result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
        return result_tens

    def build_arithmetic_right_shift(self, op, a, b, round):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)

        attr = ts.TosaSerializerAttribute()
        attr.ArithmeticRightShiftAttribute(round)

        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
        return result_tens

    def build_mul(self, op, a, b, shift):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)

        # Special for multiply:
        # Force the result to INT32 for INT types
        if a.dtype != DType.FLOAT:
            result_tens.setDtype(DType.INT32)

        attr = ts.TosaSerializerAttribute()
        attr.MulAttribute(shift)

        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
        return result_tens

    def build_table(self, op, a):
        # Constant size depending on type, random values
        if a.dtype == DType.INT16:
            table_dtype = DType.INT16
            table_arr = self.getRandTensor([513], table_dtype)
        else:
            assert a.dtype == DType.INT8
            table_dtype = DType.INT8
            table_arr = self.getRandTensor([256], table_dtype)

        table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
        result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
        self.ser.addOperator(op, [a.name, table_tens.name], [result_tens.name], None)

        return result_tens

    def build_select(self, op, cond, a, b):

        # Replace the cond tensor with a boolean tensor since it probably
        # has the wrong dtype
        t = self.buildPlaceholderTensors([cond.shape], [DType.BOOL])
        cond = t[0]

        result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
        self.ser.addOperator(op, [cond.name, a.name, b.name], [result_tens.name])

        return result_tens

    def build_comparison(self, op, a, b):
        result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
        return result_tens

    def build_argmax(self, op, a, axis):
        result_tens = OutputShaper.argmaxOp(self.ser, a, axis)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_pool2d(self, op, input, kernel, stride, pad, qinfo=None):
        result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)

        attr = ts.TosaSerializerAttribute()
        attr.Pool2dAttribute(kernel, stride, pad)

        self.ser.addOperator(op, [input.name], [result_tens.name], attr, qinfo)
        return result_tens

    def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
        assert len(padding) == 4
        result_tens = OutputShaper.conv2dOp(
            self.ser, ifm, filter, strides, padding, dilations
        )

        attr = ts.TosaSerializerAttribute()
        attr.Conv2dAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_transpose_conv2d(
        self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
    ):
        assert len(outpad) == 2
        result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)

        attr = ts.TosaSerializerAttribute()
        attr.TransposeConv2DAttribute(outpad, stride, dilation, output_shape)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_depthwise_conv2d(
        self, op, ifm, filter, bias, strides, padding, dilations, qinfo
    ):
        result_tens = OutputShaper.depthwiseConv2dOp(
            self.ser, ifm, filter, strides, padding, dilations
        )

        attr = ts.TosaSerializerAttribute()
        attr.Conv2dAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_fully_connected(self, op, ifm, filter, bias, qinfo):
        result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
        )
        return result_tens

    def build_matmul(self, op, a, b, qinfo):
        result_tens = OutputShaper.matmulOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], None, qinfo)
        return result_tens

    def build_reduce(self, op, a, axis):
        result_tens = OutputShaper.reduceOp(self.ser, a, axis)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        # Output names must be passed as a list, matching the other builders
        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_clamp(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        attr = ts.TosaSerializerAttribute()

        # Get two random ints
        v = [self.randInt(), self.randInt()]

        if a.dtype == DType.FLOAT:
            attr.ClampAttribute(0, 0, min(v), max(v))
        else:
            attr.ClampAttribute(min(v), max(v), 0, 0)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_leaky_relu(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        attr = ts.TosaSerializerAttribute()

        attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    # Needs an additional type/input
    def build_prelu(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        self.ser.addOperator(op, [a.name], [result_tens.name])
        return result_tens

    def build_relun(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        attr = ts.TosaSerializerAttribute()

        if a.dtype == DType.FLOAT:
            attr.ReluNAttribute(0, self.getRandNumberDType(a.dtype))
        else:
            attr.ReluNAttribute(self.getRandNumberDType(a.dtype), 0)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_sigmoid(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        self.ser.addOperator(op, [a.name], [result_tens.name])
        return result_tens

    def build_tanh(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        self.ser.addOperator(op, [a.name], [result_tens.name])
        return result_tens

    def build_concat(self, op, a, b, axis):
        result_tens = OutputShaper.concatOp(self.ser, a, b, axis)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
        return result_tens

1270
1271 def build_pad(self, op, a, padding, qinfo):
1272 result_tens = OutputShaper.padOp(self.ser, a, padding)
1273
1274 # Need to turn the padding array into a TOSA tensor here.
1275 # This is one of the few tensor operands that does not get
1276 # randomly generated
Kevin Cheng550ccc52021-03-03 11:21:43 -08001277 padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)
Eric Kunzee5e26762020-10-13 16:11:07 -07001278
Kevin Cheng550ccc52021-03-03 11:21:43 -08001279 self.ser.addOperator(
1280 op, [a.name, padding_tens.name], [result_tens.name], None, qinfo
1281 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001282
    def build_reshape(self, op, a, newShape):
        result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)

        attr = ts.TosaSerializerAttribute()
        attr.ReshapeAttribute(newShape)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_reverse(self, op, a, axis):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_transpose(self, op, a, perms):
        result_tens = OutputShaper.transposeOp(self.ser, a, perms)

        perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))

        self.ser.addOperator(op, [a.name, perms_tens.name], [result_tens.name])
        return result_tens

    def build_slice(self, op, a, begin, size):
        result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)

        attr = ts.TosaSerializerAttribute()
        attr.SliceAttribute(begin, size)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_tile(self, op, a, multiples):
        result_tens = OutputShaper.tileOp(self.ser, a, multiples)

        attr = ts.TosaSerializerAttribute()
        attr.TileAttribute(multiples)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_gather(self, op, values):

        # Create a new indices tensor
        # here with data that doesn't exceed the dimensions of the values tensor

        K = values.shape[1]  # K
        W = self.randInt(
            self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
        )  # W
        indicies_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values.shape[0], W])
        )  # (N, W)
        indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

        result_tens = OutputShaper.gatherOp(self.ser, values, indicies)

        self.ser.addOperator(op, [values.name, indicies.name], [result_tens.name])

        return result_tens

    def build_scatter(self, op, values_in, input):

        # Create a new indices tensor
        # here with data that doesn't exceed the dimensions of the values_in tensor

        K = values_in.shape[1]  # K
        W = input.shape[1]  # W
        indicies_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
        )  # (N, W)
        indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

        result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)

        self.ser.addOperator(
            op, [values_in.name, indicies.name, input.name], [result_tens.name]
        )

        return result_tens

    def build_resize(
        self,
        op,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
    ):
        result_tens = OutputShaper.resizeOp(
            self.ser,
            input,
            mode,
            stride,
            offset,
            shift,
            stride_fp,
            offset_fp,
            output_dims,
            input_dtype,
            output_dtype,
        )

        attr = ts.TosaSerializerAttribute()

        attr.ResizeAttribute(
            output_dims, stride, offset, shift, stride_fp, offset_fp, mode
        )

        self.ser.addOperator(op, [input.name], [result_tens.name], attr)
        return result_tens

    def build_identityn(self, op, val, val2):

        result_tens = OutputShaper.unaryOp(self.ser, val)
        result_tens2 = OutputShaper.unaryOp(self.ser, val2)
        self.ser.addOperator(
            op, [val.name, val2.name], [result_tens.name, result_tens2.name]
        )
        return result_tens

    def build_placeholder(self, op, val):
        # Add an identity op to avoid warning in the reference model
        return self.build_unary(Op.IDENTITY, val)

    # Type Conversion
    def build_cast(self, op, val, out_dtype):
        result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
        self.ser.addOperator(op, [val.name], [result_tens.name])
        return result_tens

    def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
        result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)

        if per_channel:
            nc = val.shape[-1]
        else:
            nc = 1

        in_type_width = self.typeWidth(val.dtype)
        out_type_width = self.typeWidth(out_dtype)

        if val.dtype == DType.INT8:
            input_zp = self.randInt(-128, 127)
            in_type_width = in_type_width + 1
        else:
            input_zp = 0

        if out_dtype == DType.INT8:
            output_zp = self.randInt(-128, 127)
            out_type_width = out_type_width + 1
        else:
            output_zp = 0

        # Calculate scale based on:
        # scale = a * (2^output_width) / (2^input_width)

        a = np.float32(self.rng.random(size=[nc]))
        scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))

        if scale32:
            # Cap the scaling at 2^31 - 1 for scale32
            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
        else:
            # Cap the scaling at 2^15 - 1 for scale16
            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)

        # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))

        multiplier_arr = np.int32(np.zeros(shape=[nc]))
        shift_arr = np.int32(np.zeros(shape=[nc]))

        for i in range(nc):
            multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
                scale_arr[i], scale32
            )
            if shift_arr[i] < 2 or shift_arr[i] > 62:
                self.ser.setExpectedFailure(True, "OpRescale: invalid shift value")

        # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))

        attr = ts.TosaSerializerAttribute()
        attr.RescaleAttribute(
            input_zp,
            output_zp,
            multiplier_arr,
            shift_arr,
            scale32,
            double_round,
            per_channel,
        )

        self.ser.addOperator(op, [val.name], [result_tens.name], attr)
        return result_tens
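
    # Illustrative example (values invented for clarity): rescaling INT8 to
    # INT32 with a random a == 0.5 gives
    #   scale = 0.5 * 2**32 / 2**9 == 4194304.0
    # (the INT8 width is widened by 1 above because a zero point is in play);
    # computeMultiplierAndShift() then folds that scale into the fixed-point
    # multiplier/shift pair carried by the RESCALE attribute.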

    def build_cond_if_const(self, op, then_tens, else_tens, cond):
        # For cond_if with constants, we're supplied with then/else tensors that we ignore
        # (except for the generated shape) and the condition. Build Then/Else blocks
        # and fill them with const nodes for the body.

        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        # Make then/else tensors
        out_shape = then_tens.shape
        then_arr = np.int32(self.rng.integers(0, 255, size=out_shape))
        else_arr = np.int32(self.rng.integers(0, 255, size=out_shape))

        # And the result tensor based on any of the outputs
        result_tens = self.ser.addOutput(out_shape, DType.INT32)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(op, [cond_tens.name], [result_tens.name], attr)

        self.ser.startBasicBlock(then_block)
        # Build the actual then/else tensors inside their blocks
        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
        self.ser.addOutputTensor(then_tens)

        self.ser.startBasicBlock(else_block)
        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
        self.ser.addOutputTensor(else_tens)

        return result_tens

    def build_cond_if_binary(self, op, a, b, cond):
        # For cond_if with a binary op in the then/else blocks, take a and b and
        # alternately add or subtract them based on the condition

        # Condition tensor
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        result_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.currBasicBlock.addOutput(result_tens.name)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(
            op, [cond_tens.name, a.name, b.name], [result_tens.name], attr
        )

        self.ser.startBasicBlock(then_block)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        then_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])

        self.ser.startBasicBlock(else_block)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        else_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])

        return result_tens
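
    # Reference semantics of the graph built above, as a plain-Python sketch
    # (illustration only, never called by the generator): the THEN block
    # computes a + b and the ELSE block computes a - b, selected by cond.
    @staticmethod
    def _sketch_cond_if_binary_semantics(a, b, cond):
        return a + b if cond else a - b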

    def build_while_loop(self, op, a, iter_val):
        iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

        cond_block = "COND_BLOCK"
        body_block = "BODY_BLOCK"

        attr = ts.TosaSerializerAttribute()
        attr.WhileLoopAttribute(cond_block, body_block)

        # Accumulator tensor
        # acc = self.ser.addOutput(a.shape, a.dtype)
        acc_init_val = np.int32(np.zeros(a.shape))
        acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

        # Intermediate/output tensors for everything going through the loop
        iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        a_out = self.ser.addIntermediate(a.shape, a.dtype)
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

        # While_loop operator
        self.ser.addOperator(
            op,
            [iter.name, a.name, acc.name],
            [iter_out.name, a_out.name, acc_out.name],
            attr,
        )

        # COND block (input: iter, output: cond_tens)
        self.ser.startBasicBlock(cond_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
        cond_tens = self.ser.addOutput([], DType.BOOL)
        self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])

        # BODY block (input: a, acc, iter, output: a, acc, iter)
        # Note that local intermediate tensors need to be declared here for the outputs
        self.ser.startBasicBlock(body_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
        iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
        self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
        self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
        self.ser.addOutputTensor(iter_body_out)
        self.ser.addOutputTensor(a)
        self.ser.addOutputTensor(acc_body_out)

        return acc_out
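
    # Reference semantics of the while_loop graph above as a plain-Python
    # sketch (illustration only): COND keeps looping while iter > 0 and BODY
    # accumulates `a` and decrements the counter, so the result is
    # a * iter_val element-wise.
    @staticmethod
    def _sketch_while_loop_semantics(a, iter_val):
        acc = np.zeros_like(a)
        while iter_val > 0:
            acc = acc + a
            iter_val = iter_val - 1
        return acc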

    def genOpTestList(
        self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None
    ):

        try:
            op = self.TOSA_OP_LIST[opName]
        except KeyError as e:
            raise Exception("Cannot find op with name {}".format(opName))

        # Initialize a new random number generator
        self.rng = np.random.default_rng(self.random_seed)

        build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]

        # Generate the lists of arguments
        rmin, rmax = op["rank"]

        # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
        default_test_rank_range = range(1, 5)

        # Test list consists of a tuple of:
        # (opName, testNameStr, dtype, shapeList, argumentsList)
        testList = []

        if not shapeFilter:
            shapeFilter = [None]

        for r in range(rmin, rmax + 1):

            # Filter out the rank?
            if rankFilter is not None and r not in rankFilter:
                continue
            if (
                rankFilter is None
                and shapeFilter[0] is None
                and r not in default_test_rank_range
            ):
                continue

            for t in op["types"]:

                # Filter tests based on dtype?
                if dtypeFilter is not None:
                    if t not in dtypeFilter:
                        continue

                # Create the placeholder and const tensors
                for shape in shapeFilter:
                    # A None shape chooses a random shape of a given rank

                    # Filter out by rank
                    if shape is not None and len(shape) != r:
                        continue

                    self.setTargetShape(shape)
                    shapeList = tgen_fcn(self, op, r)

                    shapeStr = self.shapeStr(shapeList[0])
                    typeStr = self.typeStr(t)

                    # Argument lists consist of tuples of the (str, []) string representation and the build function argument list
                    argList = []
                    if agen_fcn:
                        argList = agen_fcn(self, opName, shapeList, t)
                    else:
                        argList = [("", [])]

                    for argStr, args in argList:
                        if argStr:
                            testStr = "{}_{}_{}_{}".format(
                                opName, shapeStr, typeStr, argStr
                            )
                        else:
                            testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)

                        testList.append((opName, testStr, t, shapeList, args))

        return testList
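
    # Hypothetical usage sketch (illustration only): given a constructed
    # generator instance, enumerate the tests for one op and serialize each
    # one. The op name and filter values below are illustrative assumptions,
    # not fixed by this file.
    @staticmethod
    def _sketch_generate_and_serialize(ttg):
        for opName, testStr, dtype, shapeList, testArgs in ttg.genOpTestList(
            "add", shapeFilter=[None], rankFilter=[2, 3], dtypeFilter=[DType.INT32]
        ):
            ttg.serializeTest(opName, testStr, dtype, shapeList, testArgs)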

    def serializeTest(self, opName, testStr, dtype_or_dtypeList, shapeList, testArgs):
        try:
            op = self.TOSA_OP_LIST[opName]
        except KeyError as e:
            raise Exception("Cannot find op with name {}".format(opName))

        # Create a serializer
        self.createSerializer(opName, testStr)

        build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
        pCount, cCount = op["operands"]
        num_operands = pCount + cCount

        if isinstance(dtype_or_dtypeList, list):
            dtypeList = dtype_or_dtypeList
        else:
            dtypeList = [dtype_or_dtypeList] * num_operands

        assert (
            len(shapeList) == num_operands
        ), "shapeList length {} must match number of operands {}".format(
            len(shapeList), num_operands
        )
        assert (
            len(dtypeList) == num_operands
        ), "dtypeList length {} must match number of operands {}".format(
            len(dtypeList), num_operands
        )

        try:
            qgen = op["qgen"]
        except KeyError:
            qgen = None

        # Build the random tensor operands and the test
        tens = []

        # If test is ArithmeticRightShift, force value of operand[1] to be within [0, num_bits]
        if op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if idx == 1:
                    if dtypeList[idx] == DType.INT8:
                        arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
                    elif dtypeList[idx] == DType.INT16:
                        arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
                    elif dtypeList[idx] == DType.INT32:
                        arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
                    else:
                        raise Exception("OpArithmeticRightShift: invalid input dtype")
                else:
                    arr = self.getRandTensor(shape, dtypeList[idx])
                placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))

            tens.extend(placeholders)
        elif op["op"] == Op.DIV:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.Div must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.DIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
            while True:
                dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
                    continue

                break

            placeholders.append(
                self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            placeholders.append(
                self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            tens.extend(placeholders)
        elif op["op"] == Op.MUL:
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            if dtypeList[0] == DType.FLOAT:
                tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure the multiply result is within int32 range
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                else:
                    raise Exception("OpMul: invalid input dtype")

                for idx, shape in enumerate(shapeList[:]):
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                    a_arr = np.int32(
                        self.rng.integers(low=low, high=high, size=shapeList[0])
                    )
                    b_arr = np.int32(
                        self.rng.integers(low=low, high=high, size=shapeList[1])
                    )

                i = 0
                while True:

                    a_arr_64 = a_arr.astype(np.int64)
                    b_arr_64 = b_arr.astype(np.int64)

                    if shift > 0:
                        rounding = 1 << (shift - 1)
                        result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                    else:
                        result_arr = a_arr_64 * b_arr_64

                    if (result_arr > -(2 ** 31)).all() and (
                        result_arr <= ((2 ** 31) - 1)
                    ).all():
                        break

                    i = i + 1
                    a_arr = a_arr // 2
                    b_arr = b_arr // 2

                placeholders.append(
                    self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)
        else:
            tens.extend(
                self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
            )
            tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        if qgen is not None:
            qinfo = qgen(self, op, dtypeList[0])
        else:
            qinfo = None

        try:
            if qinfo is not None:
                resultName = build_fcn(self, op["op"], *tens, *testArgs, qinfo)
            else:
                resultName = build_fcn(self, op["op"], *tens, *testArgs)
        except TypeError as e:
            print(
                "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
                    build_fcn, tens, testArgs
                )
            )
            raise e

        # Save the serialized test
        self.serialize("test")
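
    # Worked sketch of the range check above (illustration only): for MUL with
    # a shift, the reference result is the rounded value of (a * b) / 2**shift
    # computed at 64-bit width, and operands are halved until it fits in int32.
    # The sample inputs are arbitrary assumptions.
    @staticmethod
    def _sketch_mul_shift_rounding(a=12345, b=67890, shift=6):
        product = a * b  # Python ints are arbitrary precision, like the int64 path
        if shift > 0:
            result = (product + (1 << (shift - 1))) >> shift
        else:
            result = product
        fits_int32 = -(2 ** 31) < result <= (2 ** 31) - 1
        return result, fits_int32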

    def createDynamicOpLists(self):

        # Dynamically create op lists for convolutions with a list of kernel sizes
        KERNELS = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]

        for k in KERNELS:
            testName = "conv2d_{}x{}".format(k[0], k[1])
            self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
            self.TOSA_OP_LIST[testName]["filter"] = k
            self.TOSA_OP_LIST[testName]["template"] = False

            testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
            self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
                "depthwise_conv2d_TEMPLATE"
            ].copy()
            self.TOSA_OP_LIST[testName]["filter"] = k
            self.TOSA_OP_LIST[testName]["template"] = False

            testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
            self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
                "transpose_conv2d_TEMPLATE"
            ].copy()
            self.TOSA_OP_LIST[testName]["filter"] = k
            self.TOSA_OP_LIST[testName]["template"] = False

        # Delete any templates after having created any dynamic ops
        # This is a two-pass operation because it's bad practice to delete
        # keys from dictionaries while iterating
        keyList = []
        for k in self.TOSA_OP_LIST:
            try:
                if self.TOSA_OP_LIST[k]["template"] == True:
                    keyList.append(k)
                    continue
            except KeyError:
                pass

        for k in keyList:
            del self.TOSA_OP_LIST[k]
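
    # Minimal sketch of the template-expansion pattern above, on a toy dict
    # rather than the real op list (illustration only): each kernel size
    # stamps out a concrete copy of a template entry, then the template
    # itself is removed.
    @staticmethod
    def _sketch_template_expansion():
        ops = {"conv2d_TEMPLATE": {"template": True}}
        for k in [[1, 1], [3, 3]]:
            entry = ops["conv2d_TEMPLATE"].copy()
            entry["filter"] = k
            entry["template"] = False
            ops["conv2d_{}x{}".format(k[0], k[1])] = entry
        del ops["conv2d_TEMPLATE"]
        return ops  # {'conv2d_1x1': ..., 'conv2d_3x3': ...}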

    def initOpListDefaults(self):
        """Fill in default fields for ops if they aren't already specified.
        Look for missing required fields (datastructure linting)."""
        for op in self.TOSA_OP_LIST:

            # Required fields
            try:
                pl, c = self.TOSA_OP_LIST[op]["operands"]
            except (KeyError, ValueError, TypeError):
                raise Exception(
                    "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
                )

            try:
                fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
            except (KeyError, ValueError, TypeError):
                raise Exception(
                    "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
                        op
                    )
                )

            try:
                types = self.TOSA_OP_LIST[op]["types"]
            except KeyError as e:
                raise Exception(
                    "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
                )

            try:
                opcode = self.TOSA_OP_LIST[op]["op"]
            except KeyError as e:
                raise Exception(
                    "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
                )

            # Put in default rank range, if missing
            try:
                rank = self.TOSA_OP_LIST[op]["rank"]
            except KeyError:
                self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE

    # Tensor operator list
    # 'op': op name
    # 'operands': tuple of (placeholder, const) operands
    # 'rank': optional, restricts rank to tuple inclusive of (min, max);
    #         if not specified, defaults to DEFAULT_RANK_RANGE, with testing
    #         further limited to ranks 1-4 unless a rank or shape filter is given
    # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
    # 'types': array of datatypes to be tested
    TYPE_FP = [DType.FLOAT]

    TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
    TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

    TYPE_BOOL = [DType.BOOL]
    TYPE_FI32 = [DType.FLOAT, DType.INT32]
    TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
    TYPE_FI16 = [DType.FLOAT, DType.INT16]

    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

    TYPE_CONV2D = [
        [DType.INT8, DType.INT8, DType.INT32],
        [DType.INT16, DType.INT8, DType.INT48],
        DType.FLOAT,
    ]

    DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
    TOSA_OP_LIST = {
        # Tensor operators
        "argmax": {
            "op": Op.ARGMAX,
            "operands": (1, 0),
            "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_NARROW_INT_FP,
        },
        "avg_pool2d": {
            "op": Op.AVG_POOL2D,
            "operands": (1, 0),
            "rank": (4, 4),
            "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
            "qgen": TosaQuantGen.qgUnary,
            "types": TYPE_NARROW_INT_FP,
        },
        # Templated operator. Filled in by createDynamicOpLists
        "conv2d_TEMPLATE": {
            "op": Op.CONV2D,
            "operands": (1, 2),
            "rank": (4, 4),
            "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV2D,
            "template": True,
        },
        # Conv3d TBD
        # Templated operator. Filled in by createDynamicOpLists
        "depthwise_conv2d_TEMPLATE": {
            "op": Op.DEPTHWISE_CONV2D,
            "operands": (1, 2),
            "filter": [1, 1],
            "rank": (4, 4),
            "build_fcn": (
                build_depthwise_conv2d,
                TosaTensorGen.tgDepthwiseConv2D,
                TosaArgGen.agConv2D,
            ),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV2D,
            "template": True,
        },
        "fully_connected": {
            "op": Op.FULLY_CONNECTED,
            "operands": (1, 2),
            "rank": (2, 2),
            "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV2D,
        },
        "matmul": {
            "op": Op.MATMUL,
            "operands": (2, 0),
            "rank": (3, 3),
            "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
            "qgen": TosaQuantGen.qgMatmul,
            "types": TYPE_NARROW_INT_FP,
        },
        "max_pool2d": {
            "op": Op.MAX_POOL2D,
            "operands": (1, 0),
            "rank": (4, 4),
            "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
            "types": TYPE_NARROW_INT_FP,
        },
        # Templated operator. Filled in by createDynamicOpLists
        "transpose_conv2d_TEMPLATE": {
            "op": Op.TRANSPOSE_CONV2D,
            "operands": (1, 2),
            "rank": (4, 4),
            "build_fcn": (
                build_transpose_conv2d,
                TosaTensorGen.tgTransposeConv2D,
                TosaArgGen.agTransposeConv2D,
            ),
            "qgen": TosaQuantGen.qgConv,
            "types": TYPE_CONV2D,
            "template": True,
        },
        # Activation functions
        "clamp": {
            "op": Op.CLAMP,
            "operands": (1, 0),
            "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
            "types": TYPE_NARROW_INT_FP,
        },
        "relun": {
            "op": Op.RELUN,
            "operands": (1, 0),
            "build_fcn": (build_relun, TosaTensorGen.tgBasic, None),
            "types": TYPE_FI32,
        },
        "sigmoid": {
            "op": Op.SIGMOID,
            "operands": (1, 0),
            "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "tanh": {
            "op": Op.TANH,
            "operands": (1, 0),
            "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        # Elementwise Binary Operators
        "add": {
            "op": Op.ADD,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "arithmetic_right_shift": {
            "op": Op.ARITHMETIC_RIGHT_SHIFT,
            "operands": (2, 0),
            "build_fcn": (
                build_arithmetic_right_shift,
                TosaTensorGen.tgBroadcastFuzz,
                TosaArgGen.agArithmeticRightShift,
            ),
            "types": TYPE_INT,
        },
        "bitwise_and": {
            "op": Op.BITWISE_AND,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
        },
        "bitwise_or": {
            "op": Op.BITWISE_OR,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
        },
        "bitwise_xor": {
            "op": Op.BITWISE_XOR,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
        },
        "div": {
            "op": Op.DIV,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": [DType.INT32],
        },
        "logical_and": {
            "op": Op.LOGICAL_AND,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_BOOL,
        },
        "logical_left_shift": {
            "op": Op.LOGICAL_LEFT_SHIFT,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
        },
        "logical_right_shift": {
            "op": Op.LOGICAL_RIGHT_SHIFT,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_INT,
        },
        "logical_or": {
            "op": Op.LOGICAL_OR,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_BOOL,
        },
        "logical_xor": {
            "op": Op.LOGICAL_XOR,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_BOOL,
        },
        "maximum": {
            "op": Op.MAXIMUM,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "minimum": {
            "op": Op.MINIMUM,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "mul": {
            "op": Op.MUL,
            "operands": (2, 0),
            "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
            "types": TYPE_INT_FP,
        },
        "pow": {
            "op": Op.POW,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "sub": {
            "op": Op.SUB,
            "operands": (2, 0),
            "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "table": {
            "op": Op.TABLE,
            # Use the automatic generation functions to create the input array
            # but create the table tensor in the build function, as it may be
            # a different type from the input
            "operands": (1, 0),
            "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
            "types": [DType.INT8, DType.INT16],
        },
        # Elementwise Unary operators
        "abs": {
            "op": Op.ABS,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FI32,
        },
        "bitwise_not": {
            "op": Op.BITWISE_NOT,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_INT,
        },
        "ceil": {
            "op": Op.CEIL,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "clz": {
            "op": Op.CLZ,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": [DType.INT32],
        },
        "exp": {
            "op": Op.EXP,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "floor": {
            "op": Op.FLOOR,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "log": {
            "op": Op.LOG,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "logical_not": {
            "op": Op.LOGICAL_NOT,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_BOOL,
        },
        "negate": {
            "op": Op.NEGATE,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "qgen": TosaQuantGen.qgUnary,
            "types": TYPE_INT_FP,
        },
        "reciprocal": {
            "op": Op.RECIPROCAL,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        "rsqrt": {
            "op": Op.RSQRT,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FP,
        },
        # Elementwise Ternary operators
        "select": {
            "op": Op.SELECT,
            "operands": (3, 0),
            "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FIB,
        },
        # Comparison operators
        "equal": {
            "op": Op.EQUAL,
            "operands": (2, 0),
            "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "greater_equal": {
            "op": Op.GREATER_EQUAL,
            "operands": (2, 0),
            "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        "greater": {
            "op": Op.GREATER,
            "operands": (2, 0),
            "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
            "types": TYPE_FI32,
        },
        # Reduction operators
        "reduce_all": {
            "op": Op.REDUCE_ALL,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_BOOL,
        },
        "reduce_any": {
            "op": Op.REDUCE_ANY,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_BOOL,
        },
        "reduce_max": {
            "op": Op.REDUCE_MAX,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_INT_FP,
        },
        "reduce_min": {
            "op": Op.REDUCE_MIN,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_INT_FP,
        },
        "reduce_product": {
            "op": Op.REDUCE_PRODUCT,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_FP,
        },
        "reduce_sum": {
            "op": Op.REDUCE_SUM,
            "operands": (1, 0),
            "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_FI32,
        },
        # Data layout operators
        "concat": {
            "op": Op.CONCAT,
            "operands": (2, 0),
            "build_fcn": (build_concat, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_FIB,
        },
        "pad": {
            "op": Op.PAD,
            "operands": (1, 0),
            "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
            "qgen": TosaQuantGen.qgPad,
            "types": TYPE_FIB,
        },
        "reshape": {
            "op": Op.RESHAPE,
            "operands": (1, 0),
            "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
            "types": TYPE_FIB,
        },
        "reverse": {
            "op": Op.REVERSE,
            "operands": (1, 0),
            "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
            "types": TYPE_FIB,
        },
        "slice": {
            "op": Op.SLICE,
            "operands": (1, 0),
            "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
            "types": TYPE_FIB,
        },
        "tile": {
            "op": Op.TILE,
            "operands": (1, 0),
            "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
            "types": TYPE_FIB,
        },
        "transpose": {
            "op": Op.TRANSPOSE,
            "operands": (1, 0),
            "rank": (1, 4),
            "build_fcn": (
                build_transpose,
                TosaTensorGen.tgBasic,
                TosaArgGen.agTranspose,
            ),
            "types": TYPE_FIB,
        },
        # Data nodes
        "const": {
            "op": Op.CONST,
            "operands": (1, 0),
            "build_fcn": (build_placeholder, TosaTensorGen.tgBasic, None),
            "types": TYPE_FIB,
        },
        "identity": {
            "op": Op.IDENTITY,
            "operands": (1, 0),
            "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
            "types": TYPE_FIB,
        },
        # Scatter/Gather
        "gather": {
            "op": Op.GATHER,
            # Only specify 'values' tensor here. 'indices' is generated in op building stage
            "operands": (1, 0),
            "rank": (3, 3),
            "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
            "types": TYPE_INT_FP,
        },
        "scatter": {
            "op": Op.SCATTER,
            # Only specify 'values_in' tensor here.
            # 'indices' and 'input' are generated in op building stage
            "operands": (2, 0),
            "rank": (3, 3),
            "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
            "types": TYPE_INT_FP,
        },
        # Image operations
        "resize": {
            "op": Op.RESIZE,
            "operands": (1, 0),
            "rank": (4, 4),
            "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
            "types": [DType.INT8, DType.INT16, DType.FLOAT],
        },
        # Type conversion
        "cast": {
            "op": Op.CAST,
            "operands": (1, 0),
            "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
            "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
        },
        "rescale": {
            "op": Op.RESCALE,
            "operands": (1, 0),
            "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
            "types": [DType.INT8, DType.INT16, DType.INT32, DType.INT48],
        },
        # Custom
        # Not implemented.
        # Control flow operators
        # Two variants of cond_if, one that generates one of two constant tensors (no
        # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
        # (two inputs to the basic blocks, one output)
        "cond_if_const": {
            "op": Op.COND_IF,
            "operands": (0, 2),
            "build_fcn": (
                build_cond_if_const,
                TosaTensorGen.tgBasic,
                TosaArgGen.agCondIf,
            ),
            "types": [DType.BOOL],
        },
        "cond_if_binary": {
            "op": Op.COND_IF,
            "operands": (2, 0),
            "build_fcn": (
                build_cond_if_binary,
                TosaTensorGen.tgBasic,
                TosaArgGen.agCondIf,
            ),
            "types": TYPE_FI32,
        },
        # while_loop
        "while_loop": {
            "op": Op.WHILE_LOOP,
            "operands": (0, 1),
            "build_fcn": (
                build_while_loop,
                TosaTensorGen.tgBasic,
                TosaArgGen.agWhileLoop,
            ),
            "types": [DType.INT32],
        },
    }
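
    # Minimal sketch (illustration only, not called by the generator) of how a
    # TOSA_OP_LIST entry is consumed: unpack the build/tensor-gen/arg-gen
    # triple and the operand counts, then invoke each stage with the calling
    # conventions used in genOpTestList. Parameter names are assumptions.
    @staticmethod
    def _sketch_op_entry_usage(testGen, opName):
        entry = testGen.TOSA_OP_LIST[opName]
        build_fcn, tgen_fcn, agen_fcn = entry["build_fcn"]
        pCount, cCount = entry["operands"]
        rmin, rmax = entry["rank"]
        shapeList = tgen_fcn(testGen, entry, rmin)  # shapes for all operands
        if agen_fcn:
            argList = agen_fcn(testGen, opName, shapeList, entry["types"][0])
        else:
            argList = [("", [])]
        return pCount + cCount, shapeList, argList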


class OutputShaper:
    # Methods in this class compute the expected output shape and datatype
    # for common classes of operations
    def __init__(self):
        pass

    # These methods return arguments that can be used for
    # creating a new output tensor
    @staticmethod
    def binaryBroadcastOp(ser, a, b):
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        shape = []
        for i in range(len(a.shape)):
            if a.shape[i] == 1:
                shape.append(b.shape[i])
            else:
                shape.append(a.shape[i])

        return ser.addOutput(shape, a.dtype)
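
    # Worked example of the broadcast rule above (illustration only): a
    # dimension of size 1 in `a` stretches to match `b`, so [1, 3] against
    # [2, 3] yields an output shape of [2, 3]. The default shapes are
    # arbitrary assumptions.
    @staticmethod
    def _sketch_broadcast_shape(a_shape=[1, 3], b_shape=[2, 3]):
        return [b if a == 1 else a for a, b in zip(a_shape, b_shape)]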

    @staticmethod
    def binaryNonBroadcastOp(ser, a, b):
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        shape = []
        for i in range(len(a.shape)):
            assert a.shape[i] == b.shape[i]
            shape.append(a.shape[i])

        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def unaryOp(ser, a):
        return ser.addOutput(a.shape, a.dtype)

    @staticmethod
    def selectOp(ser, cond, a, b):
        assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
        assert a.dtype == b.dtype

        shape = []
        for i in range(len(a.shape)):
            shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))

        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def binaryComparisonOp(ser, a, b):
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        # Do broadcast
        shape = []
        for i in range(len(a.shape)):
            if a.shape[i] == 1:
                shape.append(b.shape[i])
            else:
                shape.append(a.shape[i])

        # Force the output type to bool
        return ser.addOutput(shape, DType.BOOL)

    @staticmethod
    def reduceOp(ser, a, axis):

        shape = a.shape.copy()

        shape[axis] = 1

        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def argmaxOp(ser, a, axis):
        shape = a.shape.copy()
        del shape[axis]
        return ser.addOutput(shape, DType.INT32)

    @staticmethod
    def conv2dOp(ser, ifm, filter, strides, padding, dilations):

        # IFM:    NHWC
        # Filter: OHWI
        # OFM:    NHWC

        if len(padding) == 2:
            # Expand padding to 4 parameters in the case of transpose_conv2d
            # From H,W to T,B,L,R
            padding = [padding[0], padding[0], padding[1], padding[1]]

        h = (
            ifm.shape[1]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        w = (
            ifm.shape[2]
            - filter.shape[2]
            - (filter.shape[2] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        if h <= 0 or w <= 0:
            # Invalid test parameters?
            h = 0
            w = 0
            ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")

        ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        return ser.addOutput(ofm_shape, out_dtype)
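
    # Worked example of the output-height formula above (illustration only,
    # sample numbers are assumptions): with input H = 16, kernel H = 3,
    # dilation 1, padding (1, 1) and stride 2,
    # h = (16 - 3 - 0 + 1 + 1) // 2 + 1 = 8.
    @staticmethod
    def _sketch_conv2d_output_dim(in_dim=16, k=3, stride=2, pad=(1, 1), dilation=1):
        return (in_dim - k - (k - 1) * (dilation - 1) + pad[0] + pad[1]) // stride + 1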

    @staticmethod
    def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
        # IFM:    NHWC
        # Filter: HWCM
        # OFM:    NHW C*M
        h = (
            ifm.shape[1]
            - filter.shape[0]
            - (filter.shape[0] - 1) * (dilations[0] - 1)
            + padding[0]
            + padding[1]
        ) // strides[0] + 1

        w = (
            ifm.shape[2]
            - filter.shape[1]
            - (filter.shape[1] - 1) * (dilations[1] - 1)
            + padding[2]
            + padding[3]
        ) // strides[1] + 1

        if h <= 0 or w <= 0:
            # Invalid test parameters?
            h = 0
            w = 0
            ser.setExpectedFailure(
                True, "Invalid combination of depthwise_conv2d parameters"
            )

        ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        return ser.addOutput(ofm_shape, out_dtype)

    @staticmethod
    def pool2dOp(ser, ifm, kernel, stride, pad):
        # input: NHWC
        h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
        w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]

        if h <= 0 or w <= 0:
            # Invalid test parameters?
            h = 0
            w = 0
            ser.setExpectedFailure(True, "Invalid combination of pooling parameters")

        ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
        return ser.addOutput(ofm_shape, ifm.dtype)
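
    # Worked example of the pooling formula above (illustration only, sample
    # numbers are assumptions): with input H = 32, kernel 2, stride 2 and no
    # padding, h = (32 + 0 + 0 + 2 - 2) // 2 = 16.
    @staticmethod
    def _sketch_pool2d_output_dim(in_dim=32, kernel=2, stride=2, pad=(0, 0)):
        return (in_dim + pad[0] + pad[1] + stride - kernel) // stride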

    @staticmethod
    def fullyConnectedOp(ser, input, filter):
        # input:  N, IC
        # filter: OC, IC
        # output: N, OC

        output_shape = [input.shape[0], filter.shape[0]]

        if input.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif input.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif input.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(input.dtype))

        return ser.addOutput(output_shape, out_dtype)

    @staticmethod
    def matmulOp(ser, a, b):
        # a:   N, H, C
        # b:   N, C, W
        # out: N, H, W

        output_shape = [a.shape[0], a.shape[1], b.shape[2]]

        if a.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif a.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif a.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))

        return ser.addOutput(output_shape, out_dtype)

    @staticmethod
    def concatOp(ser, a, b, axis):

        output_shape = a.shape.copy()
        output_shape[axis] = a.shape[axis] + b.shape[axis]

        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def padOp(ser, a, padding):

        output_shape = a.shape.copy()

        for i in range(len(output_shape)):
            output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]

        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def reshapeOp(ser, a, shape):
        output_shape = shape.copy()

        totalElements = 1
        for i in a.shape:
            totalElements *= i

        # If there are any -1 elements, figure out what that dimension must be
        totalOutputElements = 1
        for i in output_shape:
            if i != -1:
                totalOutputElements *= i

        # And fill it in
        for i in range(len(output_shape)):
            if output_shape[i] == -1:
                output_shape[i] = totalElements // totalOutputElements

        return ser.addOutput(output_shape, a.dtype)
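
    # Worked example of the -1 inference above (illustration only, sample
    # shapes are assumptions): reshaping a [2, 3, 4] tensor (24 elements) to
    # [4, -1] fills in 24 // 4 = 6, giving [4, 6].
    @staticmethod
    def _sketch_reshape_infer(in_shape=[2, 3, 4], out_shape=[4, -1]):
        total = 1
        for d in in_shape:
            total *= d
        known = 1
        for d in out_shape:
            if d != -1:
                known *= d
        return [total // known if d == -1 else d for d in out_shape]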

    @staticmethod
    def sliceOp(ser, a, begin, size):

        output_shape = size.copy()
        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def tileOp(ser, a, multiples):

        output_shape = a.shape.copy()
        assert len(multiples) == len(output_shape)

        for i in range(len(output_shape)):
            output_shape[i] = a.shape[i] * multiples[i]

        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def transposeOp(ser, a, perms):
        output_shape = a.shape.copy()
        assert len(perms) == len(output_shape)

        for i in range(len(output_shape)):
            output_shape[i] = a.shape[perms[i]]

        return ser.addOutput(output_shape, a.dtype)

    @staticmethod
    def gatherOp(ser, values, indices):
        assert len(values.shape) == 3
        assert len(indices.shape) == 2
        assert values.shape[0] == indices.shape[0]

        output_shape = [values.shape[0], indices.shape[1], values.shape[2]]

        return ser.addOutput(output_shape, values.dtype)

    @staticmethod
    def scatterOp(ser, values_in, indices, input):
        assert len(values_in.shape) == 3
        assert len(indices.shape) == 2
        assert len(input.shape) == 3
        assert values_in.shape[0] == indices.shape[0]  # N
        assert input.shape[1] == indices.shape[1]  # W
        assert values_in.shape[2] == input.shape[2]  # C

        output_shape = values_in.shape

        return ser.addOutput(output_shape, values_in.dtype)

    @staticmethod
    def tableOp(ser, input, table_dtype):
        # Same shape as the input, but dtype dependent on table dtype
        assert table_dtype == DType.INT16 or table_dtype == DType.INT8
        output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
        return ser.addOutput(input.shape, output_dtype)

    @staticmethod
    def resizeOp(
        ser,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
    ):

        output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]

        if input_dtype == DType.FLOAT:
            if stride_fp[0] <= 0 or stride_fp[1] <= 0:
                ser.setExpectedFailure(True, "Negative or zero stride")
        else:
            if stride[0] <= 0 or stride[1] <= 0:
                ser.setExpectedFailure(True, "Negative or zero stride")

        if mode == ResizeMode.BILINEAR:
            if input_dtype == DType.INT8:
                if output_dtype != DType.INT32:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.INT16:
                if output_dtype != DType.INT48:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.FLOAT:
                if output_dtype != DType.FLOAT:
                    ser.setExpectedFailure(True, "Invalid output data type")
            else:
                ser.setExpectedFailure(True, "Invalid input data type")

        elif mode == ResizeMode.NEAREST:
            if input_dtype == DType.INT8:
                if output_dtype != DType.INT8:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.INT16:
                if output_dtype != DType.INT16:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.FLOAT:
                if output_dtype != DType.FLOAT:
                    ser.setExpectedFailure(True, "Invalid output data type")
            else:
                ser.setExpectedFailure(True, "Invalid input data type")

        else:
            ser.setExpectedFailure(True, "Invalid resize mode")

        return ser.addOutput(output_dims, output_dtype)

    @staticmethod
    def typeConversionOp(ser, val, out_dtype):
        return ser.addOutput(val.shape, out_dtype)

    @staticmethod
    def transposeConv2DOp(ser, ifm, output_shape):
        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        else:
            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))

        if output_shape[1] <= 0 or output_shape[2] <= 0:
            ser.setExpectedFailure(True, "Negative output shape")

        return ser.addOutput(output_shape, out_dtype)