#!/usr/bin/env python3

# Copyright (c) 2020-2021, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import argparse
import sys
import re
import os
import subprocess
import shlex
import json
import glob
import math
import queue
import threading
import traceback
import itertools

from enum import IntEnum, Enum, unique

# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
parent_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(
    os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
)
import tosa_serializer as ts
from tosa_serializer import *
import tosa

# Convenience variables for the flatc-generated types that should be enums, but aren't
DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()

class TosaQuantGen:
    """QuantizedInfo random generator helper functions.  Specify with 'qgen':
    in the operator definition."""

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype):
        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        if dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.UnaryQuantInfo(TosaQuantGen.getQinfo(testGen, dtype),
                             TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList):
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int; the [input, weights, accumulator] dtypes are all the same
            dtypeList = [dtype_or_dtypeList] * 3
        input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
        weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.MatMulQuantInfo(TosaQuantGen.getQinfo(testGen, dtype),
                              TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype):
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(scaleFp, scaleBits, m, multiplier, shift))

        assert multiplier <= (1 << scaleBits)
        assert shift >= 0 and shift <= 63

        return multiplier, shift
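
    # Worked example (illustrative only, not executed): for scaleFp = 0.5 and
    # scale32 = True, math.frexp(0.5) returns (0.5, 0), so
    # multiplier = round(0.5 * (1 << 31)) = 1 << 30 and shift = -0 + 31 = 31.
    # The encoded value is multiplier * 2**-shift = 2**30 / 2**31 = 0.5,
    # which reproduces scaleFp.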


class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const
    tensor data operands for the operator.  The actual random data is
    generated separately for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank):
        pl, const = opName["operands"]

        assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If this is the chosen input, pick a random dimension to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

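    # Illustrative example (not executed): for rank 3 with base shape [3, 4, 5]
    # and two operands, choosing bcast_idx = 1 and fuzz_idx = 2 produces the
    # shape list [[3, 4, 5], [3, 4, 1]]; dimension 2 of the second operand is
    # then broadcast against the first.
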
    @staticmethod
    def tgConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank):
        pl, const = op["operands"]

        assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        b_oc = testGen.makeShape(1)[0]
        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])

        return [a_shape, b_shape]


class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for
    operators that take attributes or other parameters.  The return value is a
    list of (descriptive_name, [arglist]) tuples where the descriptive_name is
    appended to the test name and the arglist is expanded as arguments to the
    operator build function."""

    def __init__(self):
        pass

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype):
        """Build the axis argument for operators that take a single axis"""
        axes = []

        shape = shapeList[0]

        for a in range(0, len(shape)):
            axes.append(("axis{}".format(a), [a]))
        return axes

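    # Illustrative example (not executed): for a rank-3 input shape, agAxis
    # returns [("axis0", [0]), ("axis1", [1]), ("axis2", [2])].  Per the class
    # docstring, each descriptive name is appended to the test name and each
    # [arglist] is expanded into the operator build function.
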
    @staticmethod
    def agConv2D(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        maxStride = testGen.args.max_conv_stride
        maxPadding = testGen.args.max_conv_padding + 1
        maxDilation = testGen.args.max_conv_dilation

        # Strides, padding, dilations
        for stride in range(0, maxStride ** 2):
            for padding in range(0, (maxPadding) ** 4):
                for dilation in range(0, maxDilation ** 2):

                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    p = [
                        (padding // (maxPadding * 4)) % maxPadding,
                        (padding // (maxPadding * 2)) % maxPadding,
                        (padding // (maxPadding * 1)) % maxPadding,
                        padding % maxPadding,
                    ]
                    d = [dilation // maxDilation + 1, dilation % maxDilation + 1]

                    # 4 padding parameters for regular conv2d
                    arg_list.append(
                        (
                            "st{}{}_pad{}{}{}{}_dilat{}{}".format(
                                s[0], s[1], p[0], p[1], p[2], p[3], d[0], d[1]
                            ),
                            [s, p, d],
                        )
                    )
        return arg_list

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        # Must be rank 4
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

        maxStride = testGen.args.max_conv_stride
        maxPadding = testGen.args.max_conv_padding + 1
        maxDilation = testGen.args.max_conv_dilation

        # Strides, padding, dilations
        for stride in range(0, maxStride ** 2):
            for out_padding in range(0, (maxPadding) ** 2):
                for dilation in range(0, maxDilation ** 2):

                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    p = [
                        (out_padding // (maxPadding * 1)) % maxPadding,
                        out_padding % maxPadding,
                    ]
                    d = [dilation // maxDilation + 1, dilation % maxDilation + 1]

                    oh = (
                        ifm_shape[1]
                        - filter_shape[1]
                        - (filter_shape[1] - 1) * (d[0] - 1)
                        + 2 * p[0]
                    ) // s[0] + 1

                    ow = (
                        ifm_shape[2]
                        - filter_shape[2]
                        - (filter_shape[2] - 1) * (d[1] - 1)
                        + 2 * p[1]
                    ) // s[1] + 1

                    # Output shape
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]

                    arg_list.append(
                        (
                            "st{}{}_outpad{}{}_dilat{}{}_os{}x{}x{}x{}".format(
                                s[0],
                                s[1],
                                p[0],
                                p[1],
                                d[0],
                                d[1],
                                os[0],
                                os[1],
                                os[2],
                                os[3],
                            ),
                            [s, p, d, os],
                        )
                    )

        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype):
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of 0/1 padding on each side of each dimension
        # This process might need some revision for >1 padding, but use rank**2 as a
        # bitmask for now
        for v in range(rank ** 2):

            # Create a flat padding array
            paddings = np.zeros((rank * 2), dtype=np.int32)

            # Fill in the 1's
            for r in range(rank * 2):
                if (v >> r) & 1:
                    paddings[r] = 1

            # Reshape back to a 2D array
            paddings = paddings.reshape((rank, 2))

            arg_list.append(("pad{0:b}".format(v), [paddings]))

        return arg_list

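    # Illustrative example (not executed): for rank 2, v runs over
    # range(rank ** 2) == range(4), producing the padding arrays
    # [[0, 0], [0, 0]], [[1, 0], [0, 0]], [[0, 1], [0, 0]] and
    # [[1, 1], [0, 0]].  Note that rank ** 2 values only exercise the
    # low-order bits of the 2 ** (rank * 2) possible 0/1 patterns.
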
    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype):
        arg_list = []

        shape = shapeList[0]
        assert len(shape) == 4

        maxStride = testGen.args.max_pooling_stride
        maxKernel = testGen.args.max_pooling_kernel
        maxPadding = testGen.args.max_pooling_padding + 1

        for kernel in range(0, maxKernel ** 2):
            for stride in range(0, maxStride ** 2):
                for padding in range(0, maxPadding ** 4):
                    s = [stride // maxStride + 1, stride % maxStride + 1]
                    k = [(kernel // maxKernel) + 2, (kernel % maxKernel) + 2]
                    p = [
                        (padding // (maxPadding * 4)) % maxPadding,
                        (padding // (maxPadding * 2)) % maxPadding,
                        (padding // (maxPadding * 1)) % maxPadding,
                        padding % maxPadding,
                    ]

                    arg_list.append(
                        (
                            "st{}{}_kern{}{}_pad{}{}{}{}".format(
                                s[0], s[1], k[0], k[1], p[0], p[1], p[2], p[3]
                            ),
                            [k, s, p],
                        )
                    )
        return arg_list

    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype):
        arg_list = []

        # Enumerate the output types here
        if inDtype == DType.INT8:
            dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT16:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
        elif inDtype == DType.INT32:
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FLOAT:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))

        return arg_list

    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype):
        arg_list = []

        # Enumerate the output types here
        for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
            if inDtype == DType.UINT8 and dtype != DType.INT8:
                # The only output dtype for UINT8 is INT8, skip all other combinations
                continue
            if inDtype != DType.INT8 and dtype == DType.UINT8:
                # The only input dtype for UINT8 is INT8, skip all other combinations
                continue

            for scale32 in [False, True]:
                for double_round in [False, True]:
                    for per_channel in [False, True]:

                        if inDtype == DType.INT48 and scale32:
                            # Illegal condition.  Must be scale32=False
                            continue
                        if double_round and not scale32:
                            # Illegal condition.  ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[dtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [dtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

    @staticmethod
    def agMul(testGen, opName, shapeList, dtype):
        arg_list = []

        if dtype is DType.INT32:
            for p in range(testGen.args.num_rand_permutations):

                shift = testGen.randInt(0, 32)

                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("perm0_shift0", [0]))

        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype):
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    # Helper function for reshape.  Gets some factors of a larger number.
    @staticmethod
    def getFactors(val, start=1):
        factors = []

        for i in range(start, int(np.sqrt(val)) + 1):
            if (val % i) == 0:
                factors.append(i)

        return factors

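    # Illustrative example (not executed): getFactors(12) returns [1, 2, 3],
    # i.e. only the factors up to sqrt(val).  agReshape below compensates by
    # re-factoring the remaining element count after each factor it consumes.
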
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype):
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast.  Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks the while loop if it runs for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank - 1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)

                # Toss in a -1 sometimes
                minusOne = testGen.randInt(0, newRank * 4)
                if minusOne < newRank:
                    newShape[minusOne] = -1

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list

    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]

        # Get all permutations
        permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to the number of permutations available for this rank or the
        # num_rand_permutations argument, whichever is smaller
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Randomly shuffle the full set of permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create a list with the required number of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list

    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            begin = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    begin.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    begin.append(0)
                    size.append(1)

            if valid:
                arg_list.append(("perm{}".format(p), [begin, size]))
        return arg_list

    @staticmethod
    def agTile(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random, but small multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                multiples.append(testGen.randInt(1, 4))

            arg_list.append(("perm{}".format(p), [multiples]))

        return arg_list

    @staticmethod
    def agResize(testGen, opName, shapeList, dtype):
        arg_list = []

        ifm_shape = shapeList[0]

        for m in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations.  Pick legal output types
            if m == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif m == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif m == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):

                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    output_dims = [testGen.randInt(1), testGen.randInt(1)]
                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]
                        arg_list.append(
                            (
                                "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
                                    "N" if m == ResizeMode.NEAREST else "B",
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride_fp[0],
                                    stride_fp[1],
                                    offset_fp[0],
                                    offset_fp[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )
                    else:
                        shift = 11
                        unit = float(1 << shift)
                        stride_y = int(round(fp_stride_y * unit))
                        stride_x = int(round(fp_stride_x * unit))
                        offset_y = int(round(fp_offset_y * unit))
                        offset_x = int(round(fp_offset_x * unit))

                        while (
                            stride_y >= 32768
                            or stride_x >= 32768
                            or offset_y >= 32768
                            or offset_x >= 32768
                            or offset_y < -32768
                            or offset_x < -32768
                        ):
                            shift = shift - 1
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                        arg_list.append(
                            (
                                "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
                                    "N" if m == ResizeMode.NEAREST else "B",
                                    shift,
                                    output_dims[0],
                                    output_dims[1],
                                    testGen.typeStr(outputDType),
                                    stride[0],
                                    stride[1],
                                    offset[0],
                                    offset[1],
                                ),
                                [
                                    m,
                                    stride,
                                    offset,
                                    shift,
                                    stride_fp,
                                    offset_fp,
                                    output_dims,
                                    dtype,
                                    outputDType,
                                ],
                            )
                        )

        return arg_list

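    # Illustrative note (reading of the code above, not executed): for the
    # integer output types, stride and offset are fixed-point values with
    # `shift` fractional bits, so with shift = 11 a floating-point stride of
    # 2.0 is encoded as round(2.0 * (1 << 11)) = 4096.  The while loop above
    # decrements shift until every stride/offset fits in the signed 16-bit
    # range (below 32768, with offsets no less than -32768).
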
    def agCondIf(testGen, opName, shapeList, dtype):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    def agWhileLoop(testGen, opName, shapeList, dtype):
        # While loop: 0 iterations, 1, more than 1
        arg_list = []

        for iter in [0, 1, 4]:
            arg_list.append(("iter{}".format(iter), [iter]))

        return arg_list


class TosaTestGen:
    # Maximum rank of tensor supported by test generator.
    TOSA_TENSOR_MAX_RANK = 6

    def __init__(self, args):
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        self.ser = None
        self.rng = np.random.default_rng(self.random_seed)
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to return a specific starting shape
        self.targetted_shape = None

    def createSerializer(self, opName, testPath):
        self.testPath = os.path.join(opName, testPath)

        fullPath = os.path.join(self.basePath, self.testPath)
        os.makedirs(fullPath, exist_ok=True)
        self.ser = ts.TosaSerializer(fullPath)

    def getSerializer(self):
        return self.ser

    def serialize(self, testName):
        with open(
            os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
        ) as fd:
            fd.write(self.ser.serialize())

        with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
            fd.write(self.ser.writeJson("{}.tosa".format(testName)))

    def getRandTensor(self, shape, dtype):
        if dtype == DType.BOOL:
            return np.bool_(self.rng.choice(a=[False, True], size=shape))
        elif dtype == DType.INT4:
            return np.int32(self.rng.integers(low=-8, high=8, size=shape))
        elif dtype == DType.INT8:
            return np.int32(self.rng.integers(low=-128, high=128, size=shape))
        elif dtype == DType.UINT8:
            return np.int32(self.rng.integers(low=0, high=256, size=shape))
        elif dtype == DType.INT16:
            return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
        elif dtype == DType.INT32:
            return np.int32(
                self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
            )
        elif dtype == DType.INT48:
            return np.int64(
                self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
            )
        elif dtype == DType.FLOAT:
            return np.float32(self.rng.random(size=shape))
        else:
            raise Exception("Unrecognized Dtype: {}".format(dtype))

    def buildPlaceholderTensors(self, shape_list, dtype_list):
        placeholders = []

        assert len(shape_list) == len(dtype_list)

        for idx, shape in enumerate(shape_list):
            arr = self.getRandTensor(shape, dtype_list[idx])
            placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))

        return placeholders

    def buildConstTensors(self, shape_list, dtype_list):
        consts = []

        assert len(shape_list) == len(dtype_list)

        for idx, shape in enumerate(shape_list):
            arr = self.getRandTensor(shape, dtype_list[idx])
            consts.append(self.ser.addConst(shape, dtype_list[idx], arr))

        return consts

    def makeShape(self, rank):
        if self.targetted_shape:
            return np.int32(self.targetted_shape)
        return np.int32(
            self.rng.integers(
                low=self.args.tensor_shape_range[0],
                high=self.args.tensor_shape_range[1],
                size=rank,
            )
        )

    def setTargetShape(self, shape):
        self.targetted_shape = shape

    def randInt(self, low=0, high=256):
        return np.int32(self.rng.integers(low=low, high=high, size=1))[0]

    def getRandNumberDType(self, dtype):
        if dtype == DType.FLOAT:
            return self.rng.random()
        elif dtype == DType.BOOL:
            return self.rng.choice([False, True])
        elif dtype == DType.INT4:
            low, high = (-8, 8)
        elif dtype == DType.INT8:
            low, high = (-128, 128)
        elif dtype == DType.INT16:
            low, high = (-32768, 32768)
        elif dtype == DType.INT32:
            low, high = (-(1 << 31), (1 << 31))
        elif dtype == DType.INT48:
            low, high = (-(1 << 47), (1 << 47))
            # INT48 is a special size and needs a 64-bit container
            return np.int64(self.rng.integers(low, high, size=1))[0]
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

        return np.int32(self.rng.integers(low, high, size=1))[0]

    def shapeStr(self, shape):
        sStr = []
        # Convert to strings
        for i in shape:
            sStr.append(str(i))

        return "x".join(sStr)

    def typeStr(self, t):
        if isinstance(t, list):
            assert len(t) >= 2
            return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
        else:
            if t == DType.BOOL:
                return "b"
            elif t == DType.INT4:
                return "i4"
            elif t == DType.INT8:
                return "i8"
            elif t == DType.UINT8:
                return "u8"
            elif t == DType.INT16:
                return "i16"
            elif t == DType.INT32:
                return "i32"
            elif t == DType.INT48:
                return "i48"
            elif t == DType.FLOAT:
                return "float"
            else:
                raise Exception("Unknown dtype, cannot convert to string: {}".format(t))

    def typeWidth(self, t):
        """Get the datatype width for integer types"""
        if t == DType.INT4:
            return 4
        elif t == DType.INT8:
            return 8
        elif t == DType.UINT8:
            return 8
        elif t == DType.INT16:
            return 16
        elif t == DType.INT32:
            return 32
        elif t == DType.INT48:
            return 48
        else:
            raise Exception("Unknown dtype, cannot determine width: {}".format(t))

    # Argument generators
    # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list]),
    # where the string descriptor is used to generate the test name and
    # the build_fcn_arg_list is expanded and passed to the operator test
    # build function

    def build_unary(self, op, a, qinfo=None):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        self.ser.addOperator(op, [a.name], [result_tens.name], None, qinfo)
        return result_tens

    def build_binary_broadcast(self, op, a, b):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
        return result_tens

    def build_binary_nonbroadcast(self, op, a, b):
        result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
        return result_tens

    def build_arithmetic_right_shift(self, op, a, b, round):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)

        attr = ts.TosaSerializerAttribute()
        attr.ArithmeticRightShiftAttribute(round)

        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
        return result_tens

    def build_mul(self, op, a, b, shift):
        result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)

        # Special for multiply:
        # Force the result to INT32 for INT types
        if a.dtype != DType.FLOAT:
            result_tens.setDtype(DType.INT32)

        attr = ts.TosaSerializerAttribute()
        attr.MulAttribute(shift)

        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
        return result_tens

    def build_table(self, op, a):
        # The table size is constant, depending on the input type;
        # the entries are random values
        if a.dtype == DType.INT16:
            table_dtype = DType.INT16
            table_arr = self.getRandTensor([513], table_dtype)
        else:
            assert a.dtype == DType.INT8
            table_dtype = DType.INT8
            table_arr = self.getRandTensor([256], table_dtype)

        table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
        result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
        self.ser.addOperator(op, [a.name, table_tens.name], [result_tens.name], None)

        return result_tens

    def build_select(self, op, cond, a, b):
        result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
        self.ser.addOperator(op, [cond.name, a.name, b.name], [result_tens.name])
        return result_tens

    def build_comparison(self, op, a, b):
        result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
        return result_tens

    def build_argmax(self, op, a, axis):
        result_tens = OutputShaper.argmaxOp(self.ser, a, axis)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_pool2d(self, op, input, kernel, stride, pad, qinfo=None):
        result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)

        attr = ts.TosaSerializerAttribute()
        attr.Pool2dAttribute(kernel, stride, pad)

        self.ser.addOperator(op, [input.name], [result_tens.name], attr, qinfo)
        return result_tens

    def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
        assert len(padding) == 4
        result_tens = OutputShaper.conv2dOp(
            self.ser, ifm, filter, strides, padding, dilations
        )

        attr = ts.TosaSerializerAttribute()
        attr.Conv2dAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_transpose_conv2d(
        self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
    ):
        assert len(outpad) == 2
        result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)

        attr = ts.TosaSerializerAttribute()
        attr.TransposeConv2DAttribute(outpad, stride, dilation, output_shape)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_depthwise_conv2d(
        self, op, ifm, filter, bias, strides, padding, dilations, qinfo
    ):
        result_tens = OutputShaper.depthwiseConv2dOp(
            self.ser, ifm, filter, strides, padding, dilations
        )

        attr = ts.TosaSerializerAttribute()
        attr.Conv2dAttribute(padding, strides, dilations)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
        )
        return result_tens

    def build_fully_connected(self, op, ifm, filter, bias, qinfo):
        result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)

        self.ser.addOperator(
            op, [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
        )
        return result_tens

    def build_matmul(self, op, a, b, qinfo):
        result_tens = OutputShaper.matmulOp(self.ser, a, b)
        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], None, qinfo)
        return result_tens

    def build_reduce(self, op, a, axis):
        result_tens = OutputShaper.reduceOp(self.ser, a, axis)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_clamp(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        attr = ts.TosaSerializerAttribute()
        v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

        if a.dtype == DType.FLOAT:
            attr.ClampAttribute(0, 0, min(v), max(v))
        else:
            attr.ClampAttribute(min(v), max(v), 0, 0)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_leaky_relu(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        attr = ts.TosaSerializerAttribute()

        attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    # Needs an additional type/input
    def build_prelu(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        self.ser.addOperator(op, [a.name], [result_tens.name])
        return result_tens

    def build_relun(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        attr = ts.TosaSerializerAttribute()

        if a.dtype == DType.FLOAT:
            attr.ReluNAttribute(0, self.getRandNumberDType(a.dtype))
        else:
            attr.ReluNAttribute(self.getRandNumberDType(a.dtype), 0)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_sigmoid(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        self.ser.addOperator(op, [a.name], [result_tens.name])
        return result_tens

    def build_tanh(self, op, a):
        result_tens = OutputShaper.unaryOp(self.ser, a)
        self.ser.addOperator(op, [a.name], [result_tens.name])
        return result_tens

    def build_concat(self, op, a, b, axis):
        result_tens = OutputShaper.concatOp(self.ser, a, b, axis)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
        return result_tens

    def build_pad(self, op, a, padding, qinfo):
        result_tens = OutputShaper.padOp(self.ser, a, padding)

        # Need to turn the padding array into a TOSA tensor here.
        # This is one of the few tensor operands that does not get
        # randomly generated
        padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)

        self.ser.addOperator(
            op, [a.name, padding_tens.name], [result_tens.name], None, qinfo
        )
        return result_tens

    def build_reshape(self, op, a, newShape):
        result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)

        attr = ts.TosaSerializerAttribute()
        attr.ReshapeAttribute(newShape)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_reverse(self, op, a, axis):
        result_tens = OutputShaper.unaryOp(self.ser, a)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_transpose(self, op, a, perms):
        result_tens = OutputShaper.transposeOp(self.ser, a, perms)

        perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))

        self.ser.addOperator(op, [a.name, perms_tens.name], [result_tens.name])
        return result_tens

    def build_slice(self, op, a, begin, size):
        result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)

        attr = ts.TosaSerializerAttribute()
        attr.SliceAttribute(begin, size)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_tile(self, op, a, multiples):
        result_tens = OutputShaper.tileOp(self.ser, a, multiples)

        attr = ts.TosaSerializerAttribute()
        attr.TileAttribute(multiples)

        self.ser.addOperator(op, [a.name], [result_tens.name], attr)
        return result_tens

    def build_gather(self, op, values):

        # Create a new indices tensor here with data that doesn't exceed
        # the dimensions of the values tensor

        K = values.shape[1]  # K
        W = self.randInt(
            self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
        )  # W
        indicies_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values.shape[0], W])
        )  # (N, W)
        indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

        result_tens = OutputShaper.gatherOp(self.ser, values, indicies)

        self.ser.addOperator(op, [values.name, indicies.name], [result_tens.name])

        return result_tens

    def build_scatter(self, op, values_in, input):

        # Create a new indices tensor here with data that doesn't exceed
        # the dimensions of the values_in tensor

        K = values_in.shape[1]  # K
        W = input.shape[1]  # W
        indicies_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
        )  # (N, W)
        indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

        result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)

        self.ser.addOperator(
            op, [values_in.name, indicies.name, input.name], [result_tens.name]
        )

        return result_tens

    def build_resize(
        self,
        op,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
    ):
        result_tens = OutputShaper.resizeOp(
            self.ser,
            input,
            mode,
            stride,
            offset,
            shift,
            stride_fp,
            offset_fp,
            output_dims,
            input_dtype,
            output_dtype,
        )

        attr = ts.TosaSerializerAttribute()

        attr.ResizeAttribute(
            output_dims, stride, offset, shift, stride_fp, offset_fp, mode
        )

        self.ser.addOperator(op, [input.name], [result_tens.name], attr)
        return result_tens

    def build_identityn(self, op, val, val2):

        result_tens = OutputShaper.unaryOp(self.ser, val)
        result_tens2 = OutputShaper.unaryOp(self.ser, val2)
        self.ser.addOperator(
            op, [val.name, val2.name], [result_tens.name, result_tens2.name]
        )
        return result_tens

    def build_placeholder(self, op, val):
        # Add an identity op to avoid warning in the reference model
        return self.build_unary(Op.IDENTITY, val)

    # Type Conversion
    def build_cast(self, op, val, out_dtype):
        result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
        self.ser.addOperator(op, [val.name], [result_tens.name])
        return result_tens

1427 def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
1428 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
1429
1430 if per_channel:
1431 nc = val.shape[-1]
1432 else:
1433 nc = 1
1434
1435 in_type_width = self.typeWidth(val.dtype)
1436 out_type_width = self.typeWidth(out_dtype)
1437
Kevin Cheng3a478572021-01-22 17:21:02 -08001438 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001439 input_zp = self.randInt(-128, 128)
1440 in_type_width = in_type_width + 1
1441 elif val.dtype == DType.UINT8:
1442 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07001443 in_type_width = in_type_width + 1
1444 else:
1445 input_zp = 0
1446
Kevin Cheng3a478572021-01-22 17:21:02 -08001447 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001448 output_zp = self.randInt(-128, 128)
1449 out_type_width = out_type_width + 1
1450 elif out_dtype == DType.UINT8:
1451 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07001452 out_type_width = out_type_width + 1
1453 else:
1454 output_zp = 0
1455
1456 # Calculate scale based on:
1457 # scale = a *(2^output_width)/(2^input_width))
1458
1459 a = np.float32(self.rng.random(size=[nc]))
1460 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
1461
        if scale32:
            # Cap the scaling at 2^31 - 1 for scale32
            scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
1466 else:
1467 # Cap the scaling at 2^15 - 1 for scale16
1468 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
1469
Kevin Cheng550ccc52021-03-03 11:21:43 -08001470 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001471
1472 multiplier_arr = np.int32(np.zeros(shape=[nc]))
1473 shift_arr = np.int32(np.zeros(shape=[nc]))
1474
1475 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001476 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
1477 scale_arr[i], scale32
1478 )
Kevin Chengaee1fac2020-11-11 13:54:06 -08001479 if shift_arr[i] < 2 or shift_arr[i] > 62:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001480 self.ser.setExpectedFailure(True, "OpRescale: invalid shift value")
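        # For reference, each scale value is decomposed so that
        #   scale ~= multiplier / (1 << shift)
        # e.g. (assumed values) a scale of 64.0 under scale32 could come back
        # as multiplier = 1 << 30 and shift = 24, since (1 << 30) >> 24 == 64.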
Eric Kunzee5e26762020-10-13 16:11:07 -07001481
Kevin Cheng550ccc52021-03-03 11:21:43 -08001482 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07001483
1484 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08001485 attr.RescaleAttribute(
1486 input_zp,
1487 output_zp,
1488 multiplier_arr,
1489 shift_arr,
1490 scale32,
1491 double_round,
1492 per_channel,
1493 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001494
1495 self.ser.addOperator(op, [val.name], [result_tens.name], attr)
1496 return result_tens
1497
1498 def build_cond_if_const(self, op, then_tens, else_tens, cond):
1499 # For cond_if with constants, we're supplied with then/else tensors that we ignore
        # (except for the generated shape) and the condition. Build Then/Else blocks
1501 # and fill them with const nodes for the body.
1502
1503 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001504 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07001505
1506 # Make then/else tensors
1507 out_shape = then_tens.shape
Jeremy Johnson18e26662021-07-22 16:15:29 +01001508 then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
1509 else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
Eric Kunzee5e26762020-10-13 16:11:07 -07001510
1511 # And the result tensor based on any of the outputs
Kevin Cheng550ccc52021-03-03 11:21:43 -08001512 result_tens = self.ser.addOutput(out_shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07001513
1514 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001515 then_block = "THEN_BLOCK"
1516 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001517 attr = ts.TosaSerializerAttribute()
1518 attr.CondIfAttribute(then_block, else_block)
1519
1520 # Finally, build the op and the two blocks
1521 self.ser.addOperator(op, [cond_tens.name], [result_tens.name], attr)
1522
1523 self.ser.startBasicBlock(then_block)
1524 # Build the actual then/else tensors inside their blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001525 then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001526 self.ser.addOutputTensor(then_tens)
1527
1528 self.ser.startBasicBlock(else_block)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001529 else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001530 self.ser.addOutputTensor(else_tens)
1531
1532 return result_tens
1533
1534 def build_cond_if_binary(self, op, a, b, cond):
1535 # For cond_if with a binary op in the then/else blocks, take a and b and
1536 # alternately add or subtract them based on the condition
1537
1538 # Condition tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001539 cond_tens = self.ser.addConst([], DType.BOOL, [cond])
Eric Kunzee5e26762020-10-13 16:11:07 -07001540
Kevin Cheng550ccc52021-03-03 11:21:43 -08001541 result_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001542 self.ser.currBasicBlock.addOutput(result_tens.name)
1543
1544 # Create the attribute with the names of the then/else blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001545 then_block = "THEN_BLOCK"
1546 else_block = "ELSE_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001547 attr = ts.TosaSerializerAttribute()
1548 attr.CondIfAttribute(then_block, else_block)
1549
1550 # Finally, build the op and the two blocks
Kevin Cheng550ccc52021-03-03 11:21:43 -08001551 self.ser.addOperator(
1552 op, [cond_tens.name, a.name, b.name], [result_tens.name], attr
1553 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001554
1555 self.ser.startBasicBlock(then_block)
1556 self.ser.addInputTensor(a)
1557 self.ser.addInputTensor(b)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001558 then_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001559 self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])
1560
1561 self.ser.startBasicBlock(else_block)
1562 self.ser.addInputTensor(a)
1563 self.ser.addInputTensor(b)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001564 else_tens = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001565 self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])
1566
1567 return result_tens
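    # Net effect of the graph above: result = (a + b) if cond else (a - b),
    # with the ADD and SUB serialized into separate then/else basic blocks.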
1568
1569 def build_while_loop(self, op, a, iter_val):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001570 iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
Eric Kunzee5e26762020-10-13 16:11:07 -07001571
Kevin Cheng550ccc52021-03-03 11:21:43 -08001572 cond_block = "COND_BLOCK"
1573 body_block = "BODY_BLOCK"
Eric Kunzee5e26762020-10-13 16:11:07 -07001574
1575 attr = ts.TosaSerializerAttribute()
1576 attr.WhileLoopAttribute(cond_block, body_block)
1577
1578 # Accumulator tensor
Kevin Cheng550ccc52021-03-03 11:21:43 -08001579 # acc = self.ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001580 acc_init_val = np.int32(np.zeros(a.shape))
Kevin Cheng550ccc52021-03-03 11:21:43 -08001581 acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)
Eric Kunzee5e26762020-10-13 16:11:07 -07001582
1583 # Intermediate/output tensors for everything going through the loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08001584 iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
1585 a_out = self.ser.addIntermediate(a.shape, a.dtype)
1586 acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001587
1588 # While_loop operator
Kevin Cheng550ccc52021-03-03 11:21:43 -08001589 self.ser.addOperator(
1590 op,
1591 [iter.name, a.name, acc.name],
1592 [iter_out.name, a_out.name, acc_out.name],
1593 attr,
1594 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001595
1596 # COND block (input: iter, output: cond_tens )
1597 self.ser.startBasicBlock(cond_block)
1598 self.ser.addInputTensor(iter)
1599 self.ser.addInputTensor(a)
1600 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001601 zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
1602 cond_tens = self.ser.addOutput([], DType.BOOL)
1603 self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001604
1605 # BODY block (input: a, acc, iter, output: a, acc, iter)
1606 # Note that local intermediate tensors need to be declared here for the outputs
1607 self.ser.startBasicBlock(body_block)
1608 self.ser.addInputTensor(iter)
1609 self.ser.addInputTensor(a)
1610 self.ser.addInputTensor(acc)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001611 one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
1612 iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
1613 acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001614 self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
1615 self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
1616 self.ser.addOutputTensor(iter_body_out)
1617 self.ser.addOutputTensor(a)
1618 self.ser.addOutputTensor(acc_body_out)
1619
1620 return acc_out
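    # Semantics sketch of the loop above: acc starts at zero, the body adds 'a'
    # and decrements 'iter' while iter > 0, so with an illustrative iter_val = 3
    # the returned acc_out holds a + a + a = 3 * a.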
1621
Kevin Cheng550ccc52021-03-03 11:21:43 -08001622 def genOpTestList(
1623 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None
1624 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07001625
1626 try:
1627 op = self.TOSA_OP_LIST[opName]
1628 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001629 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07001630
1631 # Initialize a new random number generator
1632 self.rng = np.random.default_rng(self.random_seed)
1633
Kevin Cheng550ccc52021-03-03 11:21:43 -08001634 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001635
1636 # Generate the lists of arguments
Kevin Cheng550ccc52021-03-03 11:21:43 -08001637 rmin, rmax = op["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001638
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01001639 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
1640 default_test_rank_range = range(1, 5)
1641
Eric Kunzee5e26762020-10-13 16:11:07 -07001642 # Test list consists of a tuple of:
1643 # (opName, testNameStr, dtype, shapeList, argumentsList)
1644 testList = []
1645
1646 if not shapeFilter:
1647 shapeFilter = [None]
1648
1649 for r in range(rmin, rmax + 1):
1650
1651 # Filter out the rank?
1652 if rankFilter is not None and r not in rankFilter:
1653 continue
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01001654 if rankFilter is None and shapeFilter[0] is None and r not in default_test_rank_range:
1655 continue
Eric Kunzee5e26762020-10-13 16:11:07 -07001656
Kevin Cheng550ccc52021-03-03 11:21:43 -08001657 for t in op["types"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07001658
1659 # Filter tests based on dtype?
1660 if dtypeFilter is not None:
Les Bell30e46802021-07-23 09:43:31 +01001661 if not (t in dtypeFilter or (isinstance(t, list) and t[0] in dtypeFilter)):
Eric Kunzee5e26762020-10-13 16:11:07 -07001662 continue
1663
1664 # Create the placeholder and const tensors
1665 for shape in shapeFilter:
1666 # A None shape chooses a random shape of a given rank
1667
1668 # Filter out by rank
1669 if shape is not None and len(shape) != r:
1670 continue
1671
1672 self.setTargetShape(shape)
1673 shapeList = tgen_fcn(self, op, r)
1674
1675 shapeStr = self.shapeStr(shapeList[0])
1676 typeStr = self.typeStr(t)
1677
                    # The argument list consists of (str, []) tuples pairing a
                    # string representation with the build function's argument list
1679 argList = []
1680 if agen_fcn:
1681 argList = agen_fcn(self, opName, shapeList, t)
1682 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001683 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07001684
1685 for argStr, args in argList:
1686 if argStr:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001687 testStr = "{}_{}_{}_{}".format(
1688 opName, shapeStr, typeStr, argStr
1689 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001690 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001691 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001692
1693 testList.append((opName, testStr, t, shapeList, args))
1694
1695 return testList
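    # Hypothetical usage sketch ('gen' is an assumed TosaTestGen instance):
    # each returned tuple can be passed straight to serializeTest, e.g.
    #   for opName, testStr, dtype, shapeList, args in gen.genOpTestList("add"):
    #       gen.serializeTest(opName, testStr, dtype, shapeList, args)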
1696
Kevin Cheng989cb052021-04-28 16:29:44 -07001697 def serializeTest(self, opName, testStr, dtype_or_dtypeList, shapeList, testArgs):
Eric Kunzee5e26762020-10-13 16:11:07 -07001698 try:
1699 op = self.TOSA_OP_LIST[opName]
1700 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001701 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07001702
1703 # Create a serializer
1704 self.createSerializer(opName, testStr)
1705
Kevin Cheng550ccc52021-03-03 11:21:43 -08001706 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
1707 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07001708 num_operands = pCount + cCount
1709
1710 if isinstance(dtype_or_dtypeList, list):
1711 dtypeList = dtype_or_dtypeList
1712 else:
1713 dtypeList = [dtype_or_dtypeList] * (num_operands)
1714
1715 assert (
1716 len(shapeList) == num_operands
1717 ), "shapeList length {} must match number of operands {}".format(
1718 len(shapeList), num_operands
1719 )
1720 assert (
1721 len(dtypeList) == num_operands
1722 ), "dtypeList length {} must match number of operands {}".format(
1723 len(dtypeList), num_operands
1724 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001725
1726 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001727 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001728 except KeyError:
1729 qgen = None
1730
1731 # Build the random tensor operands and the test
1732 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08001733
1734 # If test is ArithmeticRightShift, force value of operand[1] to be within [0, num_bits]
Kevin Cheng550ccc52021-03-03 11:21:43 -08001735 if op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
1736 assert (
1737 pCount == 2 and cCount == 0
1738 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
Kevin Chengaee1fac2020-11-11 13:54:06 -08001739
1740 placeholders = []
1741 for idx, shape in enumerate(shapeList[:]):
1742 if idx == 1:
Kevin Cheng989cb052021-04-28 16:29:44 -07001743 if dtypeList[idx] == DType.INT8:
Kevin Chengaee1fac2020-11-11 13:54:06 -08001744 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07001745 elif dtypeList[idx] == DType.INT16:
Kevin Chengaee1fac2020-11-11 13:54:06 -08001746 arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07001747 elif dtypeList[idx] == DType.INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -08001748 arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
1749 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001750 raise Exception("OpArithmeticRightShift: invalid input dtype")
Kevin Chengaee1fac2020-11-11 13:54:06 -08001751 else:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001752 arr = self.getRandTensor(shape, dtypeList[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -07001753 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
Kevin Chengaee1fac2020-11-11 13:54:06 -08001754
1755 tens.extend(placeholders)
Matthew Haddona44ac5e2021-07-27 16:31:16 +01001756 elif op["op"] == Op.SELECT:
1757 # Set datatype of condition tensor to boolean
1758 dtypeList[0] = DType.BOOL
1759 tens.extend(
1760 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
1761 )
1762 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001763 elif op["op"] == Op.DIV:
1764 assert (
1765 pCount == 2 and cCount == 0
1766 ), "Op.Div must have 2 placeholders, 0 consts"
1767
1768 placeholders = []
1769
1770 # Two invalid cases for Op.DIV:
1771 # 1. divisor == 0
Kevin Cheng47315e12021-05-13 17:41:28 -07001772 # 2. dividend == -(1<<31) and divisor == -1
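            # (Case 2 would overflow: -(1 << 31) / -1 == 1 << 31, one past the
            # int32 maximum of (1 << 31) - 1.)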
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001773 while True:
1774 dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
1775 divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])
1776
1777 if (divisor_arr == 0).any():
1778 continue
1779
Kevin Cheng47315e12021-05-13 17:41:28 -07001780 if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001781 continue
1782
1783 break
1784
1785 placeholders.append(
1786 self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1787 )
1788 placeholders.append(
1789 self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1790 )
1791
1792 tens.extend(placeholders)
1793 elif op["op"] == Op.MUL:
1794 assert (
1795 pCount == 2 and cCount == 0
1796 ), "Op.MUL must have 2 placeholders, 0 consts"
1797
1798 if dtypeList[0] == DType.FLOAT:
1799 tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
1800 else:
1801 placeholders = []
1802
                # Make sure the multiply result stays within int32 range
1804 shift = testArgs[0]
1805 if dtypeList[0] == DType.INT8:
1806 num_bits = 8
1807 elif dtypeList[0] == DType.INT16:
1808 num_bits = 16
1809 elif dtypeList[0] == DType.INT32:
1810 num_bits = 32
1811 else:
1812 raise Exception("OpMul: invalid input dtype")
1813
1814 for idx, shape in enumerate(shapeList[:]):
1815 low = -(2 ** (num_bits - 1))
1816 high = (2 ** (num_bits - 1)) - 1
1817
1818 a_arr = np.int32(
1819 self.rng.integers(low=low, high=high, size=shapeList[0])
1820 )
1821 b_arr = np.int32(
1822 self.rng.integers(low=low, high=high, size=shapeList[1])
1823 )
1824
1825 i = 0
1826 while True:
1827
1828 a_arr_64 = a_arr.astype(np.int64)
1829 b_arr_64 = b_arr.astype(np.int64)
1830
1831 if shift > 0:
1832 rounding = 1 << (shift - 1)
1833 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
1834 else:
1835 result_arr = a_arr_64 * b_arr_64
1836
1837 if (result_arr > -(2 ** 31)).all() and (
1838 result_arr <= ((2 ** 31) - 1)
1839 ).all():
1840 break
1841
1842 i = i + 1
1843 a_arr = a_arr // 2
1844 b_arr = b_arr // 2
1845
1846 placeholders.append(
1847 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1848 )
1849 placeholders.append(
1850 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1851 )
1852
1853 tens.extend(placeholders)
Kevin Chengaee1fac2020-11-11 13:54:06 -08001854 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07001855 tens.extend(
1856 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
1857 )
1858 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001859
1860 if qgen is not None:
Les Bell30e46802021-07-23 09:43:31 +01001861 qinfo = qgen(self, op, dtype_or_dtypeList)
Eric Kunzee5e26762020-10-13 16:11:07 -07001862 else:
1863 qinfo = None
1864
1865 try:
1866 if qinfo is not None:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001867 resultName = build_fcn(self, op["op"], *tens, *testArgs, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07001868 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001869 resultName = build_fcn(self, op["op"], *tens, *testArgs)
Eric Kunzee5e26762020-10-13 16:11:07 -07001870 except TypeError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001871 print(
1872 "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
1873 build_fcn, tens, testArgs
1874 )
1875 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001876 raise e
1877
1878 # Save the serialized test
Kevin Cheng550ccc52021-03-03 11:21:43 -08001879 self.serialize("test")
Eric Kunzee5e26762020-10-13 16:11:07 -07001880
1881 def createDynamicOpLists(self):
1882
1883 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng550ccc52021-03-03 11:21:43 -08001884 KERNELS = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07001885
1886 for k in KERNELS:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001887 testName = "conv2d_{}x{}".format(k[0], k[1])
1888 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
1889 self.TOSA_OP_LIST[testName]["filter"] = k
1890 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001891
Kevin Cheng550ccc52021-03-03 11:21:43 -08001892 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
1893 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
1894 "depthwise_conv2d_TEMPLATE"
1895 ].copy()
1896 self.TOSA_OP_LIST[testName]["filter"] = k
1897 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001898
Kevin Cheng550ccc52021-03-03 11:21:43 -08001899 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
1900 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
1901 "transpose_conv2d_TEMPLATE"
1902 ].copy()
1903 self.TOSA_OP_LIST[testName]["filter"] = k
1904 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001905
1906 # Delete any templates after having created any dynamic ops
1907 # This is a two-pass operation because it's bad practice to delete
1908 # keys from dictionaries while iterating
1909 keyList = []
1910 for k in self.TOSA_OP_LIST:
1911 try:
                if self.TOSA_OP_LIST[k]["template"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07001913 keyList.append(k)
1914 continue
1915 except KeyError:
1916 pass
1917
1918 for k in keyList:
1919 del self.TOSA_OP_LIST[k]
1920
1921 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001922 """Fill in default fields for ops if they aren't already specified.
1923 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07001924 for op in self.TOSA_OP_LIST:
1925
1926 # Required fields
1927 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001928 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001929 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001930 raise Exception(
1931 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
1932 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001933
1934 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001935 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001936 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001937 raise Exception(
1938 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
1939 op
1940 )
1941 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001942
1943 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001944 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001945 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001946 raise Exception(
1947 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
1948 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001949
1950 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001951 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001952 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001953 raise Exception(
1954 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
1955 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001956
1957 # Put in default rank range, if missing
1958 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001959 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001960 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001961 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07001962
1963 # Tensor operator list
1964 # 'op': op name
1965 # 'operands': tuple of (placeholder, const) operands
Kevin Cheng3a478572021-01-22 17:21:02 -08001966 # 'rank': optional, restricts rank to tuple inclusive of (min, max),
1967 # if not specified, defaults to (1, 4)
Eric Kunzee5e26762020-10-13 16:11:07 -07001968 # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
1969 # 'types': array of datatypes to be tested
Kevin Cheng550ccc52021-03-03 11:21:43 -08001970 TYPE_FP = [DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07001971
Kevin Cheng550ccc52021-03-03 11:21:43 -08001972 TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
1973 TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4
Eric Kunzee5e26762020-10-13 16:11:07 -07001974
Kevin Cheng550ccc52021-03-03 11:21:43 -08001975 TYPE_BOOL = [DType.BOOL]
1976 TYPE_FI32 = [DType.FLOAT, DType.INT32]
1977 TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
1978 TYPE_FI16 = [DType.FLOAT, DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -07001979
Kevin Cheng550ccc52021-03-03 11:21:43 -08001980 TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -07001981
Kevin Cheng989cb052021-04-28 16:29:44 -07001982 TYPE_CONV2D = [
1983 [DType.INT8, DType.INT8, DType.INT32],
1984 [DType.INT16, DType.INT8, DType.INT48],
1985 DType.FLOAT,
1986 ]
1987
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01001988 DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07001989
1990 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08001991 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08001992 "argmax": {
1993 "op": Op.ARGMAX,
1994 "operands": (1, 0),
1995 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
1996 "types": TYPE_NARROW_INT_FP,
1997 },
Jared Smolens573ecd42021-03-04 15:24:10 -08001998 "avg_pool2d": {
1999 "op": Op.AVG_POOL2D,
2000 "operands": (1, 0),
2001 "rank": (4, 4),
2002 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
2003 "qgen": TosaQuantGen.qgUnary,
2004 "types": TYPE_NARROW_INT_FP,
2005 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002006 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002007 "conv2d_TEMPLATE": {
2008 "op": Op.CONV2D,
2009 "operands": (1, 2),
2010 "rank": (4, 4),
2011 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D),
2012 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07002013 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002014 "template": True,
2015 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002016 # Conv3d TBD
Eric Kunzee5e26762020-10-13 16:11:07 -07002017 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002018 "depthwise_conv2d_TEMPLATE": {
2019 "op": Op.DEPTHWISE_CONV2D,
2020 "operands": (1, 2),
2021 "filter": [1, 1],
2022 "rank": (4, 4),
2023 "build_fcn": (
2024 build_depthwise_conv2d,
2025 TosaTensorGen.tgDepthwiseConv2D,
2026 TosaArgGen.agConv2D,
2027 ),
2028 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07002029 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002030 "template": True,
2031 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002032 "fully_connected": {
2033 "op": Op.FULLY_CONNECTED,
2034 "operands": (1, 2),
2035 "rank": (2, 2),
2036 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
2037 "qgen": TosaQuantGen.qgConv,
2038 "types": TYPE_CONV2D,
2039 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002040 "matmul": {
2041 "op": Op.MATMUL,
2042 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002043 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08002044 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
2045 "qgen": TosaQuantGen.qgMatmul,
2046 "types": TYPE_NARROW_INT_FP,
2047 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002048 "max_pool2d": {
2049 "op": Op.MAX_POOL2D,
2050 "operands": (1, 0),
2051 "rank": (4, 4),
2052 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
2053 "types": TYPE_NARROW_INT_FP,
2054 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002055 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002056 "transpose_conv2d_TEMPLATE": {
2057 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002058 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002059 "rank": (4, 4),
2060 "build_fcn": (
2061 build_transpose_conv2d,
2062 TosaTensorGen.tgTransposeConv2D,
2063 TosaArgGen.agTransposeConv2D,
2064 ),
2065 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07002066 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002067 "template": True,
2068 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002069 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002070 "clamp": {
2071 "op": Op.CLAMP,
2072 "operands": (1, 0),
2073 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
2074 "types": TYPE_NARROW_INT_FP,
2075 },
2076 "relun": {
2077 "op": Op.RELUN,
2078 "operands": (1, 0),
2079 "build_fcn": (build_relun, TosaTensorGen.tgBasic, None),
2080 "types": TYPE_FI32,
2081 },
2082 "sigmoid": {
2083 "op": Op.SIGMOID,
2084 "operands": (1, 0),
2085 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
2086 "types": TYPE_FP,
2087 },
2088 "tanh": {
2089 "op": Op.TANH,
2090 "operands": (1, 0),
2091 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
2092 "types": TYPE_FP,
2093 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002094 # Elementwise Binary Operators
2095 "add": {
2096 "op": Op.ADD,
2097 "operands": (2, 0),
2098 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2099 "types": TYPE_FI32,
2100 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002101 "arithmetic_right_shift": {
2102 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2103 "operands": (2, 0),
2104 "build_fcn": (
2105 build_arithmetic_right_shift,
2106 TosaTensorGen.tgBroadcastFuzz,
2107 TosaArgGen.agArithmeticRightShift,
2108 ),
2109 "types": TYPE_INT,
2110 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002111 "bitwise_and": {
2112 "op": Op.BITWISE_AND,
2113 "operands": (2, 0),
2114 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2115 "types": TYPE_INT,
2116 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002117 "bitwise_or": {
2118 "op": Op.BITWISE_OR,
2119 "operands": (2, 0),
2120 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2121 "types": TYPE_INT,
2122 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002123 "bitwise_xor": {
2124 "op": Op.BITWISE_XOR,
2125 "operands": (2, 0),
2126 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2127 "types": TYPE_INT,
2128 },
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002129 "div": {
2130 "op": Op.DIV,
2131 "operands": (2, 0),
2132 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2133 "types": [DType.INT32],
2134 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002135 "logical_and": {
2136 "op": Op.LOGICAL_AND,
2137 "operands": (2, 0),
2138 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2139 "types": TYPE_BOOL,
2140 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002141 "logical_left_shift": {
2142 "op": Op.LOGICAL_LEFT_SHIFT,
2143 "operands": (2, 0),
2144 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2145 "types": TYPE_INT,
2146 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002147 "logical_right_shift": {
2148 "op": Op.LOGICAL_RIGHT_SHIFT,
2149 "operands": (2, 0),
2150 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2151 "types": TYPE_INT,
2152 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002153 "logical_or": {
2154 "op": Op.LOGICAL_OR,
2155 "operands": (2, 0),
2156 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2157 "types": TYPE_BOOL,
2158 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002159 "logical_xor": {
2160 "op": Op.LOGICAL_XOR,
2161 "operands": (2, 0),
2162 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2163 "types": TYPE_BOOL,
2164 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002165 "maximum": {
2166 "op": Op.MAXIMUM,
2167 "operands": (2, 0),
2168 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2169 "types": TYPE_FI32,
2170 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002171 "minimum": {
2172 "op": Op.MINIMUM,
2173 "operands": (2, 0),
2174 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2175 "types": TYPE_FI32,
2176 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002177 "mul": {
2178 "op": Op.MUL,
2179 "operands": (2, 0),
2180 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
2181 "types": TYPE_INT_FP,
2182 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002183 "pow": {
2184 "op": Op.POW,
2185 "operands": (2, 0),
2186 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
2187 "types": TYPE_FP,
2188 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002189 "sub": {
2190 "op": Op.SUB,
2191 "operands": (2, 0),
2192 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2193 "types": TYPE_FI32,
2194 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002195 "table": {
2196 "op": Op.TABLE,
2197 # Use the automatic generation functions to create the input array
2198 # but create the table tensor in the build function, as it may be
2199 # a different type from the input
2200 "operands": (1, 0),
2201 "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002202 "types": [DType.INT8, DType.INT16],
Jared Smolens573ecd42021-03-04 15:24:10 -08002203 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002204 # Elementwise Unary operators
2205 "abs": {
2206 "op": Op.ABS,
2207 "operands": (1, 0),
2208 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2209 "types": TYPE_FI32,
2210 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002211 "bitwise_not": {
2212 "op": Op.BITWISE_NOT,
2213 "operands": (1, 0),
2214 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2215 "types": TYPE_INT,
2216 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002217 "ceil": {
2218 "op": Op.CEIL,
2219 "operands": (1, 0),
2220 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2221 "types": TYPE_FP,
2222 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002223 "clz": {
2224 "op": Op.CLZ,
2225 "operands": (1, 0),
2226 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2227 "types": [DType.INT32],
2228 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002229 "exp": {
2230 "op": Op.EXP,
2231 "operands": (1, 0),
2232 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2233 "types": TYPE_FP,
2234 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002235 "floor": {
2236 "op": Op.FLOOR,
2237 "operands": (1, 0),
2238 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2239 "types": TYPE_FP,
2240 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002241 "log": {
2242 "op": Op.LOG,
2243 "operands": (1, 0),
2244 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2245 "types": TYPE_FP,
2246 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002247 "logical_not": {
2248 "op": Op.LOGICAL_NOT,
2249 "operands": (1, 0),
2250 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2251 "types": TYPE_BOOL,
2252 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002253 "negate": {
2254 "op": Op.NEGATE,
2255 "operands": (1, 0),
2256 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2257 "qgen": TosaQuantGen.qgUnary,
2258 "types": TYPE_INT_FP,
2259 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002260 "reciprocal": {
2261 "op": Op.RECIPROCAL,
2262 "operands": (1, 0),
2263 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2264 "types": TYPE_FP,
2265 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002266 "rsqrt": {
2267 "op": Op.RSQRT,
2268 "operands": (1, 0),
2269 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2270 "types": TYPE_FP,
2271 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002272 # Elementwise Ternary operators
2273 "select": {
2274 "op": Op.SELECT,
2275 "operands": (3, 0),
2276 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
2277 "types": TYPE_FIB,
2278 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002279 # Comparison operators
2280 "equal": {
2281 "op": Op.EQUAL,
2282 "operands": (2, 0),
2283 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2284 "types": TYPE_FI32,
2285 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002286 "greater_equal": {
2287 "op": Op.GREATER_EQUAL,
2288 "operands": (2, 0),
2289 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2290 "types": TYPE_FI32,
2291 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002292 "greater": {
2293 "op": Op.GREATER,
2294 "operands": (2, 0),
2295 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2296 "types": TYPE_FI32,
2297 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002298 # Reduction operators
2299 "reduce_all": {
2300 "op": Op.REDUCE_ALL,
2301 "operands": (1, 0),
2302 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2303 "types": TYPE_BOOL,
2304 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002305 "reduce_any": {
2306 "op": Op.REDUCE_ANY,
2307 "operands": (1, 0),
2308 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2309 "types": TYPE_BOOL,
2310 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002311 "reduce_max": {
2312 "op": Op.REDUCE_MAX,
2313 "operands": (1, 0),
2314 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2315 "types": TYPE_INT_FP,
2316 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002317 "reduce_min": {
            "op": Op.REDUCE_MIN,
2319 "operands": (1, 0),
2320 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2321 "types": TYPE_INT_FP,
2322 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002323 "reduce_product": {
2324 "op": Op.REDUCE_PRODUCT,
2325 "operands": (1, 0),
2326 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2327 "types": TYPE_FP,
2328 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002329 "reduce_sum": {
2330 "op": Op.REDUCE_SUM,
2331 "operands": (1, 0),
2332 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2333 "types": TYPE_FI32,
2334 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002335 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002336 "concat": {
2337 "op": Op.CONCAT,
2338 "operands": (2, 0),
2339 "build_fcn": (build_concat, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2340 "types": TYPE_FIB,
2341 },
2342 "pad": {
2343 "op": Op.PAD,
2344 "operands": (1, 0),
2345 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
2346 "qgen": TosaQuantGen.qgPad,
2347 "types": TYPE_FIB,
2348 },
2349 "reshape": {
2350 "op": Op.RESHAPE,
2351 "operands": (1, 0),
2352 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
2353 "types": TYPE_FIB,
2354 },
2355 "reverse": {
2356 "op": Op.REVERSE,
2357 "operands": (1, 0),
2358 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2359 "types": TYPE_FIB,
2360 },
2361 "slice": {
2362 "op": Op.SLICE,
2363 "operands": (1, 0),
2364 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
2365 "types": TYPE_FIB,
2366 },
2367 "tile": {
2368 "op": Op.TILE,
2369 "operands": (1, 0),
2370 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
2371 "types": TYPE_FIB,
2372 },
2373 "transpose": {
2374 "op": Op.TRANSPOSE,
2375 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01002376 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002377 "build_fcn": (
2378 build_transpose,
2379 TosaTensorGen.tgBasic,
2380 TosaArgGen.agTranspose,
2381 ),
2382 "types": TYPE_FIB,
2383 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002384 # Data nodes
2385 "const": {
2386 "op": Op.CONST,
2387 "operands": (1, 0),
2388 "build_fcn": (build_placeholder, TosaTensorGen.tgBasic, None),
2389 "types": TYPE_FIB,
2390 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002391 "identity": {
2392 "op": Op.IDENTITY,
2393 "operands": (1, 0),
2394 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2395 "types": TYPE_FIB,
2396 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002397 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08002398 "gather": {
2399 "op": Op.GATHER,
2400 # Only specify 'values' tensor here. 'indices' is generated in op building stage
2401 "operands": (1, 0),
2402 "rank": (3, 3),
2403 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
2404 "types": TYPE_INT_FP,
2405 },
2406 "scatter": {
2407 "op": Op.SCATTER,
2408 # Only specify 'values_in' tensor here.
            # 'indices' and 'input' are generated in op building stage
2410 "operands": (2, 0),
2411 "rank": (3, 3),
2412 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
2413 "types": TYPE_INT_FP,
2414 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002415 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08002416 "resize": {
2417 "op": Op.RESIZE,
2418 "operands": (1, 0),
2419 "rank": (4, 4),
2420 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
2421 "types": [DType.INT8, DType.INT16, DType.FLOAT],
2422 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002423 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08002424 "cast": {
2425 "op": Op.CAST,
2426 "operands": (1, 0),
2427 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
2428 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
2429 },
2430 "rescale": {
2431 "op": Op.RESCALE,
2432 "operands": (1, 0),
2433 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002434 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002435 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002436 # Custom
2437 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08002438 # Control flow operators
        # Two variants of cond_if: one that generates one of two constant tensors (no
2440 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
2441 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002442 "cond_if_const": {
2443 "op": Op.COND_IF,
2444 "operands": (0, 2),
2445 "build_fcn": (
2446 build_cond_if_const,
2447 TosaTensorGen.tgBasic,
2448 TosaArgGen.agCondIf,
2449 ),
2450 "types": [DType.BOOL],
2451 },
2452 "cond_if_binary": {
2453 "op": Op.COND_IF,
2454 "operands": (2, 0),
2455 "build_fcn": (
2456 build_cond_if_binary,
2457 TosaTensorGen.tgBasic,
2458 TosaArgGen.agCondIf,
2459 ),
2460 "types": TYPE_FI32,
2461 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002462 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002463 "while_loop": {
2464 "op": Op.WHILE_LOOP,
2465 "operands": (0, 1),
2466 "build_fcn": (
2467 build_while_loop,
2468 TosaTensorGen.tgBasic,
2469 TosaArgGen.agWhileLoop,
2470 ),
2471 "types": [DType.INT32],
2472 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002473 }
2474
Kevin Cheng550ccc52021-03-03 11:21:43 -08002475
Eric Kunzee5e26762020-10-13 16:11:07 -07002476class OutputShaper:
2477 # Methods in this class compute the expected output shape and datatype
2478 # for common classes of operations
2479 def __init__(self):
2480 pass
2481
2482 # These methods return arguments that can be used for
2483 # creating a new output tensor
2484 @staticmethod
2485 def binaryBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002486 assert len(a.shape) == len(b.shape)
2487 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002488
2489 shape = []
2490 for i in range(len(a.shape)):
2491 if a.shape[i] == 1:
2492 shape.append(b.shape[i])
2493 else:
2494 shape.append(a.shape[i])
2495
Kevin Cheng550ccc52021-03-03 11:21:43 -08002496 return ser.addOutput(shape, a.dtype)
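    # Broadcast example for the loop above (illustrative): a.shape = [1, 4, 1]
    # with b.shape = [3, 4, 5] yields an output shape of [3, 4, 5]; only
    # size-1 dimensions of 'a' are widened to match 'b'.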
Eric Kunzee5e26762020-10-13 16:11:07 -07002497
2498 @staticmethod
2499 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002500 assert len(a.shape) == len(b.shape)
2501 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002502
2503 shape = []
2504 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002505 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07002506 shape.append(a.shape[i])
2507
Kevin Cheng550ccc52021-03-03 11:21:43 -08002508 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002509
2510 @staticmethod
2511 def unaryOp(ser, a):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002512 return ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002513
2514 @staticmethod
2515 def selectOp(ser, cond, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002516 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
2517 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002518
2519 shape = []
2520 for i in range(len(a.shape)):
2521 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
2522
Kevin Cheng550ccc52021-03-03 11:21:43 -08002523 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002524
2525 @staticmethod
2526 def binaryComparisonOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002527 assert len(a.shape) == len(b.shape)
2528 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002529
2530 # Do broadcast
2531 shape = []
2532 for i in range(len(a.shape)):
2533 if a.shape[i] == 1:
2534 shape.append(b.shape[i])
2535 else:
2536 shape.append(a.shape[i])
2537
2538 # Force the output type to bool
Kevin Cheng550ccc52021-03-03 11:21:43 -08002539 return ser.addOutput(shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07002540
2541 @staticmethod
2542 def reduceOp(ser, a, axis):
2543
2544 shape = a.shape.copy()
2545
2546 shape[axis] = 1
2547
Kevin Cheng550ccc52021-03-03 11:21:43 -08002548 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002549
2550 @staticmethod
2551 def argmaxOp(ser, a, axis):
2552 shape = a.shape.copy()
2553 del shape[axis]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002554 return ser.addOutput(shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07002555
2556 @staticmethod
2557 def conv2dOp(ser, ifm, filter, strides, padding, dilations):
2558
2559 # IFM: NHWC
2560 # Filter: OHWI
2561 # OFM: NHWC
2562
2563 if len(padding) == 2:
2564 # Expand padding to 4 parameters in the case of transpose_conv2d
2565 # From H,W to T,B,L,R
2566 padding = [padding[0], padding[0], padding[1], padding[1]]
2567
Kevin Cheng550ccc52021-03-03 11:21:43 -08002568 h = (
2569 ifm.shape[1]
2570 - filter.shape[1]
2571 - (filter.shape[1] - 1) * (dilations[0] - 1)
2572 + padding[0]
2573 + padding[1]
2574 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002575
Kevin Cheng550ccc52021-03-03 11:21:43 -08002576 w = (
2577 ifm.shape[2]
2578 - filter.shape[2]
2579 - (filter.shape[2] - 1) * (dilations[1] - 1)
2580 + padding[2]
2581 + padding[3]
2582 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002583
2584 if h <= 0 or w <= 0:
2585 # Invalid test parameters?
2586 h = 0
2587 w = 0
Kevin Cheng550ccc52021-03-03 11:21:43 -08002588 ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002589
2590 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
2591
Kevin Cheng3a478572021-01-22 17:21:02 -08002592 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002593 out_dtype = DType.INT32
2594 elif ifm.dtype == DType.INT16:
2595 out_dtype = DType.INT48
2596 elif ifm.dtype == DType.FLOAT:
2597 out_dtype = DType.FLOAT
2598 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002599 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002600
Kevin Cheng550ccc52021-03-03 11:21:43 -08002601 return ser.addOutput(ofm_shape, out_dtype)
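    # Worked example of the shape math above (assumed values): IFM
    # [1, 16, 16, 8], OHWI filter [16, 3, 3, 8], strides [1, 1], padding
    # [0, 0, 0, 0], dilations [1, 1]:
    #   h = (16 - 3 - (3 - 1) * 0 + 0 + 0) // 1 + 1 = 14
    # and w likewise, so the OFM shape is [1, 14, 14, 16].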
Eric Kunzee5e26762020-10-13 16:11:07 -07002602
2603 @staticmethod
2604 def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
2605 # IFM: NHWC
2606 # Filter: HWCM
2607 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08002608 h = (
2609 ifm.shape[1]
2610 - filter.shape[0]
2611 - (filter.shape[0] - 1) * (dilations[0] - 1)
2612 + padding[0]
2613 + padding[1]
2614 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002615
Kevin Cheng550ccc52021-03-03 11:21:43 -08002616 w = (
2617 ifm.shape[2]
2618 - filter.shape[1]
2619 - (filter.shape[1] - 1) * (dilations[1] - 1)
2620 + padding[2]
2621 + padding[3]
2622 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002623
2624 if h <= 0 or w <= 0:
2625 # Invalid test parameters?
2626 h = 0
2627 w = 0
Kevin Cheng550ccc52021-03-03 11:21:43 -08002628 ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002629
2630 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
2631
Kevin Cheng3a478572021-01-22 17:21:02 -08002632 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002633 out_dtype = DType.INT32
2634 elif ifm.dtype == DType.INT16:
2635 out_dtype = DType.INT48
2636 elif ifm.dtype == DType.FLOAT:
2637 out_dtype = DType.FLOAT
2638 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002639 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002640
Kevin Cheng550ccc52021-03-03 11:21:43 -08002641 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002642
2643 @staticmethod
2644 def pool2dOp(ser, ifm, kernel, stride, pad):
2645 # input: NHWC
2646 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
2647 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
2648
2649 if h <= 0 or w <= 0:
2650 # Invalid test parameters?
2651 h = 0
2652 w = 0
Kevin Cheng550ccc52021-03-03 11:21:43 -08002653 ser.setExpectedFailure(True, "Invalid combination of pooling parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002654
2655 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002656 return ser.addOutput(ofm_shape, ifm.dtype)
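    # Worked example (assumed values): ifm [1, 8, 8, 4], kernel [2, 2],
    # stride [2, 2], pad [0, 0, 0, 0]:
    #   h = (8 + 0 + 0 + 2 - 2) // 2 = 4
    # and w = 4, giving an ofm shape of [1, 4, 4, 4].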
Eric Kunzee5e26762020-10-13 16:11:07 -07002657
2658 @staticmethod
2659 def fullyConnectedOp(ser, input, filter):
2660 # input: N, IC
2661 # filter: OC, IC
2662 # output: N, OC
2663
2664 output_shape = [input.shape[0], filter.shape[0]]
2665
Kevin Cheng3a478572021-01-22 17:21:02 -08002666 if input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002667 out_dtype = DType.INT32
2668 elif input.dtype == DType.INT16:
2669 out_dtype = DType.INT48
2670 elif input.dtype == DType.FLOAT:
2671 out_dtype = DType.FLOAT
2672 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002673 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002674
Kevin Cheng550ccc52021-03-03 11:21:43 -08002675 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002676
2677 @staticmethod
2678 def matmulOp(ser, a, b):
Kevin Cheng2d60f002021-06-09 14:18:32 -07002679 # a: N, H, C
2680 # b: N, C, W
2681 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07002682
Kevin Cheng2d60f002021-06-09 14:18:32 -07002683 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002684
Kevin Cheng3a478572021-01-22 17:21:02 -08002685 if a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002686 out_dtype = DType.INT32
2687 elif a.dtype == DType.INT16:
2688 out_dtype = DType.INT48
2689 elif a.dtype == DType.FLOAT:
2690 out_dtype = DType.FLOAT
2691 else:
            raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002693
Kevin Cheng550ccc52021-03-03 11:21:43 -08002694 return ser.addOutput(output_shape, out_dtype)
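    # Shape sketch (illustrative): a of shape [N, H, C] = [2, 3, 4] with
    # b of shape [N, C, W] = [2, 4, 5] produces an output of shape [2, 3, 5].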
Eric Kunzee5e26762020-10-13 16:11:07 -07002695
2696 @staticmethod
2697 def concatOp(ser, a, b, axis):
2698
2699 output_shape = a.shape.copy()
2700 output_shape[axis] = a.shape[axis] + b.shape[axis]
2701
Kevin Cheng550ccc52021-03-03 11:21:43 -08002702 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002703
2704 @staticmethod
2705 def padOp(ser, a, padding):
2706
2707 output_shape = a.shape.copy()
2708
2709 for i in range(len(output_shape)):
2710 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
2711
Kevin Cheng550ccc52021-03-03 11:21:43 -08002712 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002713
2714 @staticmethod
2715 def reshapeOp(ser, a, shape):
2716 output_shape = shape.copy()
2717
2718 totalElements = 1
2719 for i in a.shape:
2720 totalElements *= i
2721
2722 # If there are any -1 elements, figure out what that dimension must be
2723 totalOutputElements = 1
2724 for i in output_shape:
2725 if i != -1:
2726 totalOutputElements *= i
2727
2728 # And fill it in
2729 for i in range(len(output_shape)):
2730 if output_shape[i] == -1:
2731 output_shape[i] = totalElements // totalOutputElements
2732
Kevin Cheng550ccc52021-03-03 11:21:43 -08002733 return ser.addOutput(output_shape, a.dtype)
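    # Example of the -1 inference above (illustrative): reshaping a [2, 3, 4]
    # tensor (24 elements) to [4, -1] resolves the -1 to 24 // 4 = 6, for an
    # output shape of [4, 6].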
Eric Kunzee5e26762020-10-13 16:11:07 -07002734
2735 @staticmethod
2736 def sliceOp(ser, a, begin, size):
2737
2738 output_shape = size.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002739 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002740
2741 @staticmethod
2742 def tileOp(ser, a, multiples):
2743
2744 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002745 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07002746
2747 for i in range(len(output_shape)):
2748 output_shape[i] = a.shape[i] * multiples[i]
2749
Kevin Cheng550ccc52021-03-03 11:21:43 -08002750 return ser.addOutput(output_shape, a.dtype)
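    # Example (illustrative): a.shape = [2, 3] with multiples = [2, 2] gives
    # an output shape of [4, 6].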
Eric Kunzee5e26762020-10-13 16:11:07 -07002751
2752 @staticmethod
2753 def transposeOp(ser, a, perms):
2754 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002755 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07002756
2757 for i in range(len(output_shape)):
2758 output_shape[i] = a.shape[perms[i]]
2759
Kevin Cheng550ccc52021-03-03 11:21:43 -08002760 return ser.addOutput(output_shape, a.dtype)
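    # Example (illustrative): a.shape = [1, 2, 3] with perms = [2, 0, 1] gives
    # output_shape = [a.shape[2], a.shape[0], a.shape[1]] = [3, 1, 2].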
Eric Kunzee5e26762020-10-13 16:11:07 -07002761
2762 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -08002763 def gatherOp(ser, values, indices):
2764 assert len(values.shape) == 3
2765 assert len(indices.shape) == 2
2766 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07002767
Kevin Cheng77d0f762020-11-24 10:26:32 -08002768 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
2769
Kevin Cheng550ccc52021-03-03 11:21:43 -08002770 return ser.addOutput(output_shape, values.dtype)
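    # Shape sketch (illustrative): values of shape [N, K, C] = [2, 6, 3]
    # gathered with indices [N, W] = [2, 4] yield an output of shape [2, 4, 3].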
Kevin Cheng77d0f762020-11-24 10:26:32 -08002771
2772 @staticmethod
2773 def scatterOp(ser, values_in, indices, input):
2774 assert len(values_in.shape) == 3
2775 assert len(indices.shape) == 2
2776 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08002777 assert values_in.shape[0] == indices.shape[0] # N
2778 assert input.shape[1] == indices.shape[1] # W
2779 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08002780
2781 output_shape = values_in.shape
2782
Kevin Cheng550ccc52021-03-03 11:21:43 -08002783 return ser.addOutput(output_shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002784
2785 @staticmethod
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002786 def tableOp(ser, input, table_dtype):
2787 # Same shape as the input, but dtype dependent on table dtype
2788 assert table_dtype == DType.INT16 or table_dtype == DType.INT8
2789 output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
2790 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002791
2792 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08002793 def resizeOp(
2794 ser,
2795 input,
2796 mode,
2797 stride,
2798 offset,
2799 shift,
2800 stride_fp,
2801 offset_fp,
2802 output_dims,
2803 input_dtype,
2804 output_dtype,
2805 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002806
2807 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
2808
Kevin Cheng77d0f762020-11-24 10:26:32 -08002809 if input_dtype == DType.FLOAT:
2810 if stride_fp[0] <= 0 or stride_fp[1] <= 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002811 ser.setExpectedFailure(True, "Negative or zero stride")
Kevin Cheng77d0f762020-11-24 10:26:32 -08002812 else:
2813 if stride[0] <= 0 or stride[1] <= 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002814 ser.setExpectedFailure(True, "Negative or zero stride")
Eric Kunzee5e26762020-10-13 16:11:07 -07002815
        if mode == ResizeMode.BILINEAR:
            if input_dtype == DType.INT8:
                if output_dtype != DType.INT32:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.INT16:
                if output_dtype != DType.INT48:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.FLOAT:
                if output_dtype != DType.FLOAT:
                    ser.setExpectedFailure(True, "Invalid output data type")
            else:
                ser.setExpectedFailure(True, "Invalid input data type")

        elif mode == ResizeMode.NEAREST:
            if input_dtype == DType.INT8:
                if output_dtype != DType.INT8:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.INT16:
                if output_dtype != DType.INT16:
                    ser.setExpectedFailure(True, "Invalid output data type")
            elif input_dtype == DType.FLOAT:
                if output_dtype != DType.FLOAT:
                    ser.setExpectedFailure(True, "Invalid output data type")
            else:
                ser.setExpectedFailure(True, "Invalid input data type")

        else:
            ser.setExpectedFailure(True, "Invalid resize mode")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002844
Kevin Cheng550ccc52021-03-03 11:21:43 -08002845 return ser.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002846
2847 @staticmethod
2848 def typeConversionOp(ser, val, out_dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002849 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002850
2851 @staticmethod
2852 def transposeConv2DOp(ser, ifm, output_shape):
Kevin Cheng3a478572021-01-22 17:21:02 -08002853 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002854 out_dtype = DType.INT32
2855 elif ifm.dtype == DType.INT16:
2856 out_dtype = DType.INT48
2857 elif ifm.dtype == DType.FLOAT:
2858 out_dtype = DType.FLOAT
2859 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002860 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002861
2862 if output_shape[1] <= 0 or output_shape[2] <= 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002863 ser.setExpectedFailure(True, "Negative output shape")
Eric Kunzee5e26762020-10-13 16:11:07 -07002864
Kevin Cheng550ccc52021-03-03 11:21:43 -08002865 return ser.addOutput(output_shape, out_dtype)