blob: 5c25f8e4db7cdfcc3338d83ad503da391c98a691 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001#!/usr/bin/env python3
2
Kevin Cheng3a478572021-01-22 17:21:02 -08003# Copyright (c) 2020-2021, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07004#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18import numpy as np
19import argparse
20import sys
21import re
22import os
23import subprocess
24import shlex
25import json
26import glob
27import math
28import queue
29import threading
30import traceback
31import math
Jeremy Johnsona6185572021-06-21 15:55:35 +010032import itertools
Eric Kunzee5e26762020-10-13 16:11:07 -070033
34from enum import IntEnum, Enum, unique
35
Kevin Cheng550ccc52021-03-03 11:21:43 -080036# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
37parent_dir = os.path.dirname(os.path.realpath(__file__))
38sys.path.append(
39 os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
40)
Eric Kunzee5e26762020-10-13 16:11:07 -070041import tosa_serializer as ts
42from tosa_serializer import *
43import tosa
44
45# Convenience variables to the flatc-generated types that should be enums, but aren't
46DType = tosa.DType.DType()
Kevin Cheng550ccc52021-03-03 11:21:43 -080047Op = tosa.Op.Op()
Eric Kunzee5e26762020-10-13 16:11:07 -070048ResizeMode = tosa.ResizeMode.ResizeMode()
49
Kevin Cheng550ccc52021-03-03 11:21:43 -080050
class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getQinfo(testGen, dtype):
        """Return a random zero point appropriate for dtype.

        INT8 draws from the full signed range [-128, 127], UINT8 from the
        full unsigned range [0, 255]; any other dtype gets a zero point of 0.
        """
        if dtype == DType.INT8:
            return testGen.randInt(-128, 128)
        if dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype):
        """Build UnaryQuantInfo with random input and output zero points."""
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.UnaryQuantInfo(TosaQuantGen.getQinfo(testGen, dtype),
                             TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList):
        """Build ConvQuantInfo with random input and weight zero points.

        dtype_or_dtypeList is either a single dtype used for all operands or
        a list of [input, weights, accumulator] dtypes.
        """
        qinfo = ts.TosaSerializerQuantInfo()
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3
        input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
        weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
        qinfo.ConvQuantInfo(input_zp, weights_zp)
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype):
        """Build MatMulQuantInfo with random A and B zero points."""
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.MatMulQuantInfo(TosaQuantGen.getQinfo(testGen, dtype),
                              TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def qgPad(testGen, op, dtype):
        """Build PadQuantInfo with a random input zero point."""
        qinfo = ts.TosaSerializerQuantInfo()
        qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Decompose a floating-point scale into a fixed-point (multiplier, shift).

        Derived from computeMultiplierAndShiftTosaScale32.  scale32 selects a
        32-bit multiplier (scaleBits=31) versus a 16-bit one (scaleBits=15).
        Returns (multiplier, shift) with 0 <= shift <= 63.
        """
        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        # frexp: scaleFp == m * 2**shift with 0.5 <= |m| < 1 (for nonzero input)
        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        # Rounding can push the mantissa up to exactly 2**scaleBits;
        # renormalize so the multiplier stays in range.
        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        # Convert the frexp exponent into a right-shift amount
        shift = (-shift) + scaleBits

        assert multiplier <= (1 << scaleBits)
        assert shift >= 0 and shift <= 63

        return multiplier, shift
129
130
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator. The actual random data is generated separately for each test."""

    def __init__(self):
        pass

    # NOTE(review): in several methods below the first op argument is named
    # 'opName' but is indexed like the operator dict (op["operands"]) — it is
    # the op definition, not a name string.

    @staticmethod
    def tgBasic(testGen, opName, rank):
        """Return identical random shapes for every placeholder and const operand."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank):
        """Return identical random rank-4 (NHWC) shapes, batch capped by --max-batch-size."""
        pl, const = opName["operands"]

        assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank):
        """Return [values_in, input] shapes for scatter/gather-style ops.

        Both shapes share batch (N) and channel (C) dims; the second shape
        gets an independently random middle (W) dimension.
        """
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank):
        """Return per-operand shapes where one randomly chosen operand has a
        randomly chosen dimension collapsed to 1 to exercise broadcasting."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        bcast_idx = testGen.randInt(0, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank):
        """Return [ifm (NHWC), filter (OHWI), bias (OC)] shapes for CONV2D."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank):
        """Return [ifm (NHWC), filter (OHWI), bias (OC)] shapes for TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank):
        """Return [ifm (NHWC), filter (HWCM), bias (C*M)] shapes for DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank):
        """Return [input (N,IC), filter (OC,IC), bias (OC)] shapes for FULLY_CONNECTED."""
        pl, const = op["operands"]

        assert rank == 2

        input_shape = testGen.makeShape(rank)
        filter_oc = (
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )[0]
        )
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank):
        """Return [A (N,H,C), B (N,C,W)] shapes for batched MATMUL."""
        pl, const = op["operands"]

        assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)
        b_oc = testGen.makeShape(1)[0]
        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])

        return [a_shape, b_shape]
332
Kevin Cheng550ccc52021-03-03 11:21:43 -0800333
Eric Kunzee5e26762020-10-13 16:11:07 -0700334class TosaArgGen:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800335 """Argument generators create exhaustive or random lists of attributes for operators that take
336 attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
337 tuples where the descriptive_name is appended to the test name and the arglist is expanded
338 as arguments to the operator build function."""
339
Eric Kunzee5e26762020-10-13 16:11:07 -0700340 def __init__(self):
341 pass
342
343 @staticmethod
344 def agNone(testGen, opName, shapeList, dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800345 """A trivial argument generator for operators that don't take any
346 non-tensor arguments"""
347 return [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -0700348
349 @staticmethod
350 def agAxis(testGen, opName, shapeList, dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800351 """Build the axis argument for operators that take a single axis"""
Eric Kunzee5e26762020-10-13 16:11:07 -0700352 axes = []
353
354 shape = shapeList[0]
355
356 for a in range(0, len(shape)):
Matthew Haddon43e37192021-07-09 14:13:02 +0100357 axes.append(("axis{}".format(a), [a]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700358 return axes
359
360 @staticmethod
361 def agConv2D(testGen, opName, shapeList, dtype):
362 arg_list = []
363
364 ifm_shape = shapeList[0]
365 filter_shape = shapeList[1]
366
367 # Must be rank 4
Kevin Cheng550ccc52021-03-03 11:21:43 -0800368 assert len(ifm_shape) == 4
369 assert len(filter_shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700370
371 maxStride = testGen.args.max_conv_stride
372 maxPadding = testGen.args.max_conv_padding + 1
373 maxDilation = testGen.args.max_conv_dilation
374
375 # Strides, padding, dilations
376 for stride in range(0, maxStride ** 2):
377 for padding in range(0, (maxPadding) ** 4):
378 for dilation in range(0, maxDilation ** 2):
379
Kevin Cheng550ccc52021-03-03 11:21:43 -0800380 s = [stride // maxStride + 1, stride % maxStride + 1]
381 p = [
382 (padding // (maxPadding * 4)) % maxPadding,
383 (padding // (maxPadding * 2)) % maxPadding,
384 (padding // (maxPadding * 1)) % maxPadding,
385 padding % maxPadding,
386 ]
387 d = [dilation // maxDilation + 1, dilation % maxDilation + 1]
Eric Kunzee5e26762020-10-13 16:11:07 -0700388
389 # 4 padding parameters for regular conv2d
Kevin Cheng550ccc52021-03-03 11:21:43 -0800390 arg_list.append(
391 (
392 "st{}{}_pad{}{}{}{}_dilat{}{}".format(
393 s[0], s[1], p[0], p[1], p[2], p[3], d[0], d[1]
394 ),
395 [s, p, d],
396 )
397 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700398 return arg_list
399
400 @staticmethod
401 def agTransposeConv2D(testGen, opName, shapeList, dtype):
402 arg_list = []
403
404 ifm_shape = shapeList[0]
405 filter_shape = shapeList[1]
406
407 # Must be rank 4
Kevin Cheng550ccc52021-03-03 11:21:43 -0800408 assert len(ifm_shape) == 4
409 assert len(filter_shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700410
411 maxStride = testGen.args.max_conv_stride
412 maxPadding = testGen.args.max_conv_padding + 1
413 maxDilation = testGen.args.max_conv_dilation
414
415 # Strides, padding, dilations
416 for stride in range(0, maxStride ** 2):
417 for out_padding in range(0, (maxPadding) ** 2):
418 for dilation in range(0, maxDilation ** 2):
419
Kevin Cheng550ccc52021-03-03 11:21:43 -0800420 s = [stride // maxStride + 1, stride % maxStride + 1]
421 p = [
422 (out_padding // (maxPadding * 1)) % maxPadding,
423 out_padding % maxPadding,
424 ]
425 d = [dilation // maxDilation + 1, dilation % maxDilation + 1]
Eric Kunzee5e26762020-10-13 16:11:07 -0700426
Kevin Cheng550ccc52021-03-03 11:21:43 -0800427 oh = (
428 ifm_shape[1]
429 - filter_shape[1]
430 - (filter_shape[1] - 1) * (d[0] - 1)
431 + 2 * p[0]
432 ) // s[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700433
Kevin Cheng550ccc52021-03-03 11:21:43 -0800434 ow = (
435 ifm_shape[2]
436 - filter_shape[2]
437 - (filter_shape[2] - 1) * (d[1] - 1)
438 + 2 * p[1]
439 ) // s[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -0700440
441 # Output shape
Kevin Cheng550ccc52021-03-03 11:21:43 -0800442 os = [ifm_shape[0], oh, ow, filter_shape[0]]
Eric Kunzee5e26762020-10-13 16:11:07 -0700443
Kevin Cheng550ccc52021-03-03 11:21:43 -0800444 arg_list.append(
445 (
446 "st{}{}_outpad{}{}_dilat{}{}_os{}x{}x{}x{}".format(
447 s[0],
448 s[1],
449 p[0],
450 p[1],
451 d[0],
452 d[1],
453 os[0],
454 os[1],
455 os[2],
456 os[3],
457 ),
458 [s, p, d, os],
459 )
460 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700461
462 return arg_list
463
464 @staticmethod
465 def agPad(testGen, opName, shapeList, dtype):
466 arg_list = []
467 rank = len(shapeList[0])
468
Les Bell7ffccce2021-07-28 15:37:02 +0100469 # Exhaustively test combinations of padding on each side of each dimension
470 # - the range of padding values is defined by pad_min and pad_max
471 # - for padding >9, the name format needs to be more distinctive
472 pad_min, pad_max = 0, 1
473 pad_values = [x for x in range(pad_min, pad_max + 1)]
474 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
475 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
Eric Kunzee5e26762020-10-13 16:11:07 -0700476
Les Bell7ffccce2021-07-28 15:37:02 +0100477 for paddings in shape_pad_values:
478 name = "pad"
479 for r in range(rank):
480 before, after = paddings[r]
481 name = f"{name}{before}{after}"
482 arg_list.append((name, [np.array(paddings)]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700483
484 return arg_list
485
486 @staticmethod
487 def agPooling(testGen, opName, shapeList, dtype):
488 arg_list = []
489
490 shape = shapeList[0]
Kevin Cheng550ccc52021-03-03 11:21:43 -0800491 assert len(shape) == 4
Eric Kunzee5e26762020-10-13 16:11:07 -0700492
493 maxStride = testGen.args.max_pooling_stride
494 maxKernel = testGen.args.max_pooling_kernel
495 maxPadding = testGen.args.max_pooling_padding + 1
496
497 for kernel in range(0, maxKernel ** 2):
498 for stride in range(0, maxStride ** 2):
499 for padding in range(0, maxPadding ** 4):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800500 s = [stride // maxStride + 1, stride % maxStride + 1]
501 k = [(kernel // maxKernel) + 2, (kernel % maxKernel) + 2]
502 p = [
503 (padding // (maxPadding * 4)) % maxPadding,
504 (padding // (maxPadding * 2)) % maxPadding,
505 (padding // (maxPadding * 1)) % maxPadding,
506 padding % maxPadding,
507 ]
Eric Kunzee5e26762020-10-13 16:11:07 -0700508
Kevin Cheng550ccc52021-03-03 11:21:43 -0800509 arg_list.append(
510 (
511 "st{}{}_kern{}{}_pad{}{}{}{}".format(
512 s[0], s[1], k[0], k[1], p[0], p[1], p[2], p[3]
513 ),
514 [k, s, p],
515 )
516 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700517 return arg_list
518
519 @staticmethod
520 def agCast(testGen, opName, shapeList, inDtype):
521 arg_list = []
522
523 # Enumerate the output types here
524 if inDtype == DType.INT8:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800525 dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700526 elif inDtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800527 dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700528 elif inDtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800529 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700530 elif inDtype == DType.BOOL:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800531 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700532 elif inDtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800533 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700534 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800535 raise Exception("Unexpected input dtype: {}".format(inDtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700536
537 for dtype in dtypeList:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800538 arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700539
540 return arg_list
541
542 @staticmethod
543 def agRescale(testGen, opName, shapeList, inDtype):
544 arg_list = []
545
546 # Enumerate the output types here
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100547 for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
548 if inDtype == DType.UINT8 and dtype != DType.INT8:
549 # The only output dtype for UINT8 is INT8, skip all other combinations
550 continue
551 if inDtype != DType.INT8 and dtype == DType.UINT8:
552 # The only input dtype for UINT8 is INT8, skip all other combinations
553 continue
554
Kevin Cheng550ccc52021-03-03 11:21:43 -0800555 for scale32 in [False, True]:
556 for double_round in [False, True]:
557 for per_channel in [False, True]:
Eric Kunzee5e26762020-10-13 16:11:07 -0700558
559 if inDtype == DType.INT48 and scale32:
560 # Illegal condition. Must be scale32=False
561 continue
Matthew Haddoncac4ee92021-07-22 14:30:53 +0100562 if double_round and not scale32:
563 # Illegal condition. ERROR_IF(!scale32 && double_round)
564 continue
Eric Kunzee5e26762020-10-13 16:11:07 -0700565
Kevin Cheng550ccc52021-03-03 11:21:43 -0800566 arg_list.append(
567 (
568 "out{}_sc{}_dr{}_pc{}".format(
569 DTypeNames[dtype],
570 int(scale32),
571 int(double_round),
572 int(per_channel),
573 ),
574 [dtype, scale32, double_round, per_channel],
575 )
576 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700577
578 return arg_list
579
Kevin Chengaee1fac2020-11-11 13:54:06 -0800580 @staticmethod
581 def agMul(testGen, opName, shapeList, dtype):
582 arg_list = []
583
584 if dtype is DType.INT32:
585 for p in range(testGen.args.num_rand_permutations):
586
587 shift = testGen.randInt(0, 32)
588
Kevin Cheng550ccc52021-03-03 11:21:43 -0800589 arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800590 else:
Matthew Haddon43e37192021-07-09 14:13:02 +0100591 arg_list.append(("perm0_shift0", [0]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800592
593 return arg_list
594
595 @staticmethod
596 def agArithmeticRightShift(testGen, opName, shapeList, dtype):
597 arg_list = []
598
Kevin Cheng550ccc52021-03-03 11:21:43 -0800599 arg_list.append(("roundTrue", [True]))
600 arg_list.append(("roundFalse", [False]))
Kevin Chengaee1fac2020-11-11 13:54:06 -0800601
602 return arg_list
603
Eric Kunzee5e26762020-10-13 16:11:07 -0700604 # Helper function for reshape. Gets some factors of a larger number.
605 @staticmethod
606 def getFactors(val, start=1):
607 factors = []
608
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100609 for i in range(start, int(np.sqrt(val)) + 1):
Eric Kunzee5e26762020-10-13 16:11:07 -0700610 if (val % i) == 0:
611 factors.append(i)
612
613 return factors
614
615 @staticmethod
616 def agReshape(testGen, opName, shapeList, dtype):
617 arg_list = []
618
619 origShape = shapeList[0]
620
621 totalElements = 1
622 for s in origShape:
623 totalElements *= s
624
625 # This code is NOT fast. Fortunately, the numbers are fairly small.
626 factors = TosaArgGen.getFactors(totalElements)
627
628 for p in range(testGen.args.num_rand_permutations):
Matthew Haddon5fc4e682021-07-07 11:28:29 +0100629 newRank = testGen.randInt(1, 7)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800630 if len(factors) < newRank:
Eric Kunzee5e26762020-10-13 16:11:07 -0700631 continue
632
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100633 found = True
634 # escape_counter breaks while loop if it continues on for too long
635 escape_counter = 0
636 while found:
637 newShape = []
638 # Generate newShape ensuring it isn't a duplicate
639 remainingElements = totalElements
640 shuffledFactors = testGen.rng.permutation(factors)
Matthew Haddon5fc4e682021-07-07 11:28:29 +0100641 for i in range(1, newRank):
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100642 # pick rank-1 factors
643 newShape.append(shuffledFactors[0])
644 remainingElements = remainingElements // shuffledFactors[0]
645 shuffledFactors = testGen.rng.permutation(
646 TosaArgGen.getFactors(remainingElements)
647 )
648 newShape.append(remainingElements)
Eric Kunzee5e26762020-10-13 16:11:07 -0700649
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100650 # Toss in a -1 sometimes
651 minusOne = testGen.randInt(0, newRank * 4)
652 if minusOne < newRank:
653 newShape[minusOne] = -1
Eric Kunzee5e26762020-10-13 16:11:07 -0700654
Matthew Haddon2ad047d2021-06-22 16:55:23 +0100655 # Check for duplicates
656 found = False
657 for name, other_shape in arg_list:
658 if other_shape[0] == newShape:
659 found = True
660 break
661
662 escape_counter += 1
663 if escape_counter >= 100:
664 break
665
666 if not found:
667 arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700668
669 return arg_list
670
Eric Kunzee5e26762020-10-13 16:11:07 -0700671 @staticmethod
672 def agTranspose(testGen, opName, shapeList, dtype):
673 arg_list = []
674
675 ifm_shape = shapeList[0]
676
Jeremy Johnsona6185572021-06-21 15:55:35 +0100677 # Get all permutations
678 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
Eric Kunzee5e26762020-10-13 16:11:07 -0700679
Jeremy Johnsona6185572021-06-21 15:55:35 +0100680 # Limit to possible permutations from shape dimension or argument setting
681 limit = min(len(permutations), testGen.args.num_rand_permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700682
Jeremy Johnsona6185572021-06-21 15:55:35 +0100683 # Get random permutation generator that uses all permutations
684 random_permutations = testGen.rng.permutation(permutations)
Eric Kunzee5e26762020-10-13 16:11:07 -0700685
Jeremy Johnsona6185572021-06-21 15:55:35 +0100686 # Create list of required amount of permutations
687 arg_list = [("perm{}".format(p), [random_permutations[p].tolist()]) for p in range(limit)]
Eric Kunzee5e26762020-10-13 16:11:07 -0700688 return arg_list
689
690 @staticmethod
691 def agSlice(testGen, opName, shapeList, dtype):
692 arg_list = []
693
694 ifm_shape = shapeList[0]
695 rank = len(ifm_shape)
696
697 for p in range(testGen.args.num_rand_permutations):
698 begin = []
699 size = []
700
Kevin Cheng550ccc52021-03-03 11:21:43 -0800701 valid = True
Eric Kunzee5e26762020-10-13 16:11:07 -0700702
703 for i in range(rank):
704 if ifm_shape[i] > 1:
705 begin.append(testGen.randInt(0, ifm_shape[i]))
706 size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))
707
708 # Invalid slice size?
709 if size[i] == 0:
710 valid = False
711 else:
712 begin.append(0)
713 size.append(1)
714
715 if valid:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800716 arg_list.append(("perm{}".format(p), [begin, size]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700717 return arg_list
718
719 @staticmethod
720 def agTile(testGen, opName, shapeList, dtype):
721 arg_list = []
722
723 ifm_shape = shapeList[0]
724 rank = len(ifm_shape)
725
726 for p in range(testGen.args.num_rand_permutations):
727
728 # Pick a few random, but small multiple values
729 # because otherwise this has a tendency to generate
730 # enormous tensors
731 multiples = []
732 for i in range(rank):
733 multiples.append(testGen.randInt(1, 4))
734
Kevin Cheng550ccc52021-03-03 11:21:43 -0800735 arg_list.append(("perm{}".format(p), [multiples]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700736
737 return arg_list
738
739 @staticmethod
740 def agResize(testGen, opName, shapeList, dtype):
741 arg_list = []
742
743 ifm_shape = shapeList[0]
744
745 for m in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
746
747 # Exclude illegal {mode, type} configurations. Pick legal output types
748 if m == ResizeMode.NEAREST and dtype == DType.INT8:
Les Bell33d837e2021-08-10 08:34:43 +0100749 outputDTypeList = [DType.INT8]
Eric Kunzee5e26762020-10-13 16:11:07 -0700750 elif m == ResizeMode.NEAREST and dtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800751 outputDTypeList = [DType.INT16]
Eric Kunzee5e26762020-10-13 16:11:07 -0700752 elif m == ResizeMode.BILINEAR and dtype == DType.INT8:
Les Bell33d837e2021-08-10 08:34:43 +0100753 outputDTypeList = [DType.INT32]
Eric Kunzee5e26762020-10-13 16:11:07 -0700754 elif m == ResizeMode.BILINEAR and dtype == DType.INT16:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800755 outputDTypeList = [DType.INT48]
Kevin Cheng77d0f762020-11-24 10:26:32 -0800756 elif dtype == DType.FLOAT:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800757 outputDTypeList = [DType.FLOAT]
Eric Kunzee5e26762020-10-13 16:11:07 -0700758 else:
759 continue
760
761 for outputDType in outputDTypeList:
762 for perm in range(testGen.args.num_rand_permutations):
763
764 # Randomly generate legal output dimensions and shift
765 # and then compute the stride and offset based on them
Kevin Cheng550ccc52021-03-03 11:21:43 -0800766 output_dims = [testGen.randInt(1), testGen.randInt(1)]
Kevin Cheng77d0f762020-11-24 10:26:32 -0800767 in_center_h = (ifm_shape[1] - 1) / 2.0
768 in_center_w = (ifm_shape[2] - 1) / 2.0
769 out_center_h = (output_dims[0] - 1) / 2.0
770 out_center_w = (output_dims[1] - 1) / 2.0
Eric Kunzee5e26762020-10-13 16:11:07 -0700771
Kevin Cheng77d0f762020-11-24 10:26:32 -0800772 fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
773 fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
774 fp_offset_y = in_center_h - fp_stride_y * out_center_h
775 fp_offset_x = in_center_w - fp_stride_x * out_center_w
Eric Kunzee5e26762020-10-13 16:11:07 -0700776
Kevin Cheng77d0f762020-11-24 10:26:32 -0800777 if outputDType == DType.FLOAT:
778 shift = 0
779 stride = [0, 0]
780 offset = [0, 0]
Kevin Cheng550ccc52021-03-03 11:21:43 -0800781 stride_fp = [fp_stride_y, fp_stride_x]
782 offset_fp = [fp_offset_y, fp_offset_x]
783 arg_list.append(
784 (
785 "mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
Les Bell33d837e2021-08-10 08:34:43 +0100786 "N" if m == ResizeMode.NEAREST else "B",
Kevin Cheng550ccc52021-03-03 11:21:43 -0800787 output_dims[0],
788 output_dims[1],
789 testGen.typeStr(outputDType),
790 stride_fp[0],
791 stride_fp[1],
792 offset_fp[0],
793 offset_fp[1],
794 ),
795 [
796 m,
797 stride,
798 offset,
799 shift,
800 stride_fp,
801 offset_fp,
802 output_dims,
803 dtype,
804 outputDType,
805 ],
806 )
807 )
Kevin Cheng77d0f762020-11-24 10:26:32 -0800808 else:
809 shift = 11
810 unit = float(1 << shift)
811 stride_y = int(round(fp_stride_y * unit))
812 stride_x = int(round(fp_stride_x * unit))
813 offset_y = int(round(fp_offset_y * unit))
814 offset_x = int(round(fp_offset_x * unit))
Eric Kunzee5e26762020-10-13 16:11:07 -0700815
Kevin Cheng550ccc52021-03-03 11:21:43 -0800816 while (
817 stride_y >= 32768
818 or stride_x >= 32768
819 or offset_y >= 32768
820 or offset_x >= 32768
821 or offset_y < -32768
822 or offset_x < -32768
823 ):
Kevin Cheng77d0f762020-11-24 10:26:32 -0800824 shift = shift - 1
825 unit = float(1 << shift)
826 stride_y = int(round(fp_stride_y * unit))
827 stride_x = int(round(fp_stride_x * unit))
828 offset_y = int(round(fp_offset_y * unit))
829 offset_x = int(round(fp_offset_x * unit))
Eric Kunzee5e26762020-10-13 16:11:07 -0700830
Kevin Cheng550ccc52021-03-03 11:21:43 -0800831 stride = [stride_y, stride_x]
832 offset = [offset_y, offset_x]
Kevin Cheng77d0f762020-11-24 10:26:32 -0800833
834 stride_fp = [0.0, 0.0]
835 offset_fp = [0.0, 0.0]
836
Kevin Cheng550ccc52021-03-03 11:21:43 -0800837 arg_list.append(
838 (
839 "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
Les Bell33d837e2021-08-10 08:34:43 +0100840 "N" if m == ResizeMode.NEAREST else "B",
Kevin Cheng550ccc52021-03-03 11:21:43 -0800841 shift,
842 output_dims[0],
843 output_dims[1],
844 testGen.typeStr(outputDType),
845 stride[0],
846 stride[1],
847 offset[0],
848 offset[1],
849 ),
850 [
851 m,
852 stride,
853 offset,
854 shift,
855 stride_fp,
856 offset_fp,
857 output_dims,
858 dtype,
859 outputDType,
860 ],
861 )
862 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700863
864 return arg_list
865
866 def agCondIf(testGen, opName, shapeList, dtype):
867 # CondIf generates the condition values here.
868 # Convert to tensors in the build function, along with the
869 # then and else blocks
870 arg_list = []
871
872 for c in [False, True]:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800873 arg_list.append(("cond{}".format(int(c)), [c]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700874
875 return arg_list
876
877 def agWhileLoop(testGen, opName, shapeList, dtype):
878 # While loop: 0 iterations, 1, more than 1
879 arg_list = []
880
881 for iter in [0, 1, 4]:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800882 arg_list.append(("iter{}".format(iter), [iter]))
Eric Kunzee5e26762020-10-13 16:11:07 -0700883
884 return arg_list
885
Kevin Cheng550ccc52021-03-03 11:21:43 -0800886
Eric Kunzee5e26762020-10-13 16:11:07 -0700887class TosaTestGen:
Jeremy Johnson97eb75f2021-07-08 11:58:02 +0100888 # Maximum rank of tensor supported by test generator.
889 TOSA_TENSOR_MAX_RANK = 6
890
    def __init__(self, args):
        """Initialize the test generator from parsed command-line *args*.

        Expects args to carry at least: output_dir, random_seed, and
        tensor_shape_range (used later by makeShape/randInt).
        """
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        # Serializer is created lazily per test by createSerializer().
        self.ser = None
        self.rng = np.random.default_rng(self.random_seed)
        # Dynamic op lists must be created before the defaults are initialized.
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None
902
903 def createSerializer(self, opName, testPath):
904 self.testPath = os.path.join(opName, testPath)
905
906 fullPath = os.path.join(self.basePath, self.testPath)
907 os.makedirs(fullPath, exist_ok=True)
908 self.ser = ts.TosaSerializer(fullPath)
909
910 def getSerializer(self):
911 return self.ser
912
913 def serialize(self, testName):
Kevin Cheng550ccc52021-03-03 11:21:43 -0800914 with open(
915 os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
916 ) as fd:
Eric Kunzee5e26762020-10-13 16:11:07 -0700917 fd.write(self.ser.serialize())
918
Kevin Cheng550ccc52021-03-03 11:21:43 -0800919 with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
920 fd.write(self.ser.writeJson("{}.tosa".format(testName)))
Eric Kunzee5e26762020-10-13 16:11:07 -0700921
922 def getRandTensor(self, shape, dtype):
Eric Kunzee5e26762020-10-13 16:11:07 -0700923 if dtype == DType.BOOL:
924 np_dt = np.bool
925 return np.bool_(self.rng.choice(a=[False, True], size=shape))
Kevin Chenga9017402021-07-28 17:19:23 -0700926 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -0700927 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -0700928 return np.int32(self.rng.integers(low=-7, high=8, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700929 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100930 return np.int32(self.rng.integers(low=-128, high=128, size=shape))
931 elif dtype == DType.UINT8:
932 return np.int32(self.rng.integers(low=0, high=256, size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700933 elif dtype == DType.INT16:
934 return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
935 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800936 return np.int32(
937 self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
938 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700939 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800940 return np.int64(
941 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
942 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700943 elif dtype == DType.FLOAT:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100944 return np.float32(self.rng.random(size=shape))
Eric Kunzee5e26762020-10-13 16:11:07 -0700945 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800946 raise Exception("Unrecognized Dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -0700947
Kevin Cheng989cb052021-04-28 16:29:44 -0700948 def buildPlaceholderTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700949 placeholders = []
950
Kevin Cheng989cb052021-04-28 16:29:44 -0700951 assert len(shape_list) == len(dtype_list)
952
953 for idx, shape in enumerate(shape_list):
954 arr = self.getRandTensor(shape, dtype_list[idx])
955 placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700956
957 return placeholders
958
Kevin Cheng989cb052021-04-28 16:29:44 -0700959 def buildConstTensors(self, shape_list, dtype_list):
Eric Kunzee5e26762020-10-13 16:11:07 -0700960 consts = []
961
Kevin Cheng989cb052021-04-28 16:29:44 -0700962 assert len(shape_list) == len(dtype_list)
963
964 for idx, shape in enumerate(shape_list):
965 arr = self.getRandTensor(shape, dtype_list[idx])
966 consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
Eric Kunzee5e26762020-10-13 16:11:07 -0700967
968 return consts
969
970 def makeShape(self, rank):
971 if self.targetted_shape:
972 return np.int32(self.targetted_shape)
Kevin Cheng550ccc52021-03-03 11:21:43 -0800973 return np.int32(
974 self.rng.integers(
975 low=self.args.tensor_shape_range[0],
976 high=self.args.tensor_shape_range[1],
977 size=rank,
978 )
979 )
Eric Kunzee5e26762020-10-13 16:11:07 -0700980
981 def setTargetShape(self, shape):
982 self.targetted_shape = shape
983
984 def randInt(self, low=0, high=256):
985 return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
986
987 def getRandNumberDType(self, dtype):
988 if dtype == DType.FLOAT:
989 return self.rng.random()
990 elif dtype == DType.BOOL:
991 return self.rng.choice([False, True])
Kevin Chenga9017402021-07-28 17:19:23 -0700992 # TOSA specific INT4 weight range from -7 to 7
Eric Kunzee5e26762020-10-13 16:11:07 -0700993 elif dtype == DType.INT4:
Kevin Chenga9017402021-07-28 17:19:23 -0700994 low, high = (-7, 8)
Eric Kunzee5e26762020-10-13 16:11:07 -0700995 elif dtype == DType.INT8:
Jeremy Johnson18e26662021-07-22 16:15:29 +0100996 low, high = (-128, 128)
Eric Kunzee5e26762020-10-13 16:11:07 -0700997 elif dtype == DType.INT16:
998 low, high = (-32768, 32768)
999 elif dtype == DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001000 low, high = (-(1 << 31), (1 << 31))
Eric Kunzee5e26762020-10-13 16:11:07 -07001001 elif dtype == DType.INT48:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001002 low, high = (-(1 << 47), (1 << 47))
Eric Kunzee5e26762020-10-13 16:11:07 -07001003 # Special size
1004 return np.int64(self.rng.integers(low, high, size=1))[0]
1005 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001006 raise Exception("Unknown dtype: {}".format(dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07001007
1008 return np.int32(self.rng.integers(low, high, size=1))[0]
1009
1010 def shapeStr(self, shape):
1011
1012 sStr = []
1013 # Convert to strings
1014 for i in shape:
1015 sStr.append(str(i))
1016
Kevin Cheng550ccc52021-03-03 11:21:43 -08001017 return "x".join(sStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001018
1019 def typeStr(self, t):
Kevin Cheng989cb052021-04-28 16:29:44 -07001020 if isinstance(t, list):
1021 assert len(t) >= 2
1022 return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001023 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07001024 if t == DType.BOOL:
1025 return "b"
1026 elif t == DType.INT4:
1027 return "i4"
1028 elif t == DType.INT8:
1029 return "i8"
1030 elif t == DType.UINT8:
1031 return "u8"
1032 elif t == DType.INT16:
1033 return "i16"
1034 elif t == DType.INT32:
1035 return "i32"
1036 elif t == DType.INT48:
1037 return "i48"
1038 elif t == DType.FLOAT:
1039 return "float"
1040 else:
1041 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001042
1043 def typeWidth(self, t):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001044 """ Get the datatype width for integer types"""
Kevin Cheng3a478572021-01-22 17:21:02 -08001045 if t == DType.INT4:
Eric Kunzee5e26762020-10-13 16:11:07 -07001046 return 4
1047 elif t == DType.INT8:
1048 return 8
Kevin Cheng3a478572021-01-22 17:21:02 -08001049 elif t == DType.UINT8:
1050 return 8
Eric Kunzee5e26762020-10-13 16:11:07 -07001051 elif t == DType.INT16:
1052 return 16
1053 elif t == DType.INT32:
1054 return 32
1055 elif t == DType.INT48:
1056 return 48
1057 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001058 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
Eric Kunzee5e26762020-10-13 16:11:07 -07001059
1060 # Argument generators
1061 # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
1062 # Where the string descriptor is used to generate the test name and
1063 # The build_fcn_arg_list is expanded and passed to the operator test
1064 # build function
1065
Kevin Cheng550ccc52021-03-03 11:21:43 -08001066 def build_unary(self, op, a, qinfo=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001067 result_tens = OutputShaper.unaryOp(self.ser, a)
1068 self.ser.addOperator(op, [a.name], [result_tens.name], None, qinfo)
1069 return result_tens
1070
1071 def build_binary_broadcast(self, op, a, b):
1072 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1073 self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
1074 return result_tens
1075
1076 def build_binary_nonbroadcast(self, op, a, b):
1077 result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
1078 self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
1079 return result_tens
1080
Kevin Chengaee1fac2020-11-11 13:54:06 -08001081 def build_arithmetic_right_shift(self, op, a, b, round):
1082 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1083
1084 attr = ts.TosaSerializerAttribute()
1085 attr.ArithmeticRightShiftAttribute(round)
1086
1087 self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
1088 return result_tens
1089
1090 def build_mul(self, op, a, b, shift):
Eric Kunzee5e26762020-10-13 16:11:07 -07001091 result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
1092
1093 # Special for multiply:
1094 # Force the result to INT32 for INT types
1095 if a.dtype != DType.FLOAT:
1096 result_tens.setDtype(DType.INT32)
1097
Kevin Chengaee1fac2020-11-11 13:54:06 -08001098 attr = ts.TosaSerializerAttribute()
1099 attr.MulAttribute(shift)
1100
1101 self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001102 return result_tens
1103
1104 def build_table(self, op, a):
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01001105 # Constant size depending on type, random values
1106 if a.dtype == DType.INT16:
1107 table_dtype = DType.INT16
1108 table_arr = self.getRandTensor([513], table_dtype)
1109 else:
1110 assert a.dtype == DType.INT8
1111 table_dtype = DType.INT8
1112 table_arr = self.getRandTensor([256], table_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001113
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01001114 table_tens = self.ser.addConst(table_arr.shape, table_dtype, table_arr)
1115 result_tens = OutputShaper.tableOp(self.ser, a, table_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07001116 self.ser.addOperator(op, [a.name, table_tens.name], [result_tens.name], None)
1117
1118 return result_tens
1119
1120 def build_select(self, op, cond, a, b):
Eric Kunzee5e26762020-10-13 16:11:07 -07001121 result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
1122 self.ser.addOperator(op, [cond.name, a.name, b.name], [result_tens.name])
Eric Kunzee5e26762020-10-13 16:11:07 -07001123 return result_tens
1124
1125 def build_comparison(self, op, a, b):
1126 result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
1127 self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
1128 return result_tens
1129
1130 def build_argmax(self, op, a, axis):
1131 result_tens = OutputShaper.argmaxOp(self.ser, a, axis)
1132
1133 attr = ts.TosaSerializerAttribute()
1134 attr.AxisAttribute(axis)
1135
1136 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1137 return result_tens
1138
Kevin Cheng550ccc52021-03-03 11:21:43 -08001139 def build_pool2d(self, op, input, kernel, stride, pad, qinfo=None):
Eric Kunzee5e26762020-10-13 16:11:07 -07001140 result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)
1141
1142 attr = ts.TosaSerializerAttribute()
1143 attr.Pool2dAttribute(kernel, stride, pad)
Eric Kunzee5e26762020-10-13 16:11:07 -07001144
1145 self.ser.addOperator(op, [input.name], [result_tens.name], attr, qinfo)
1146 return result_tens
1147
1148 def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001149 assert len(padding) == 4
1150 result_tens = OutputShaper.conv2dOp(
1151 self.ser, ifm, filter, strides, padding, dilations
1152 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001153
1154 attr = ts.TosaSerializerAttribute()
1155 attr.Conv2dAttribute(padding, strides, dilations)
1156
Kevin Cheng550ccc52021-03-03 11:21:43 -08001157 self.ser.addOperator(
1158 op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
1159 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001160 return result_tens
1161
Kevin Cheng550ccc52021-03-03 11:21:43 -08001162 def build_transpose_conv2d(
Kevin Cheng989cb052021-04-28 16:29:44 -07001163 self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08001164 ):
1165 assert len(outpad) == 2
Eric Kunzee5e26762020-10-13 16:11:07 -07001166 result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
1167
1168 attr = ts.TosaSerializerAttribute()
1169 attr.TransposeConv2DAttribute(outpad, stride, dilation, output_shape)
1170
Kevin Cheng550ccc52021-03-03 11:21:43 -08001171 self.ser.addOperator(
Kevin Cheng989cb052021-04-28 16:29:44 -07001172 op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
Kevin Cheng550ccc52021-03-03 11:21:43 -08001173 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001174 return result_tens
1175
Kevin Cheng550ccc52021-03-03 11:21:43 -08001176 def build_depthwise_conv2d(
1177 self, op, ifm, filter, bias, strides, padding, dilations, qinfo
1178 ):
1179 result_tens = OutputShaper.depthwiseConv2dOp(
1180 self.ser, ifm, filter, strides, padding, dilations
1181 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001182
1183 attr = ts.TosaSerializerAttribute()
1184 attr.Conv2dAttribute(padding, strides, dilations)
1185
Kevin Cheng550ccc52021-03-03 11:21:43 -08001186 self.ser.addOperator(
1187 op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
1188 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001189 return result_tens
1190
1191 def build_fully_connected(self, op, ifm, filter, bias, qinfo):
1192 result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)
1193
Kevin Cheng550ccc52021-03-03 11:21:43 -08001194 self.ser.addOperator(
1195 op, [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
1196 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001197 return result_tens
1198
1199 def build_matmul(self, op, a, b, qinfo):
1200 result_tens = OutputShaper.matmulOp(self.ser, a, b)
1201 self.ser.addOperator(op, [a.name, b.name], [result_tens.name], None, qinfo)
1202 return result_tens
1203
1204 def build_reduce(self, op, a, axis):
1205 result_tens = OutputShaper.reduceOp(self.ser, a, axis)
1206
1207 attr = ts.TosaSerializerAttribute()
1208 attr.AxisAttribute(axis)
1209
1210 self.ser.addOperator(op, [a.name], result_tens.name, attr)
1211 return result_tens
1212
1213 def build_clamp(self, op, a):
1214 result_tens = OutputShaper.unaryOp(self.ser, a)
1215
1216 attr = ts.TosaSerializerAttribute()
Jeremy Johnson18e26662021-07-22 16:15:29 +01001217 v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
Eric Kunzee5e26762020-10-13 16:11:07 -07001218
1219 if a.dtype == DType.FLOAT:
1220 attr.ClampAttribute(0, 0, min(v), max(v))
1221 else:
1222 attr.ClampAttribute(min(v), max(v), 0, 0)
1223
1224 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1225 return result_tens
1226
1227 def build_leaky_relu(self, op, a):
1228 result_tens = OutputShaper.unaryOp(self.ser, a)
1229 attr = ts.TosaSerializerAttribute()
1230
1231 attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
1232
1233 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1234 return result_tens
1235
1236 # Needs an additional type/input
1237 def build_prelu(self, op, a):
1238 result_tens = OutputShaper.unaryOp(self.ser, a)
1239
1240 self.ser.addOperator(op, [a.name], [result_tens.name])
1241 return result_tens
1242
1243 def build_relun(self, op, a):
1244 result_tens = OutputShaper.unaryOp(self.ser, a)
1245
1246 attr = ts.TosaSerializerAttribute()
1247
1248 if a.dtype == DType.FLOAT:
1249 attr.ReluNAttribute(0, self.getRandNumberDType(a.dtype))
1250 else:
1251 attr.ReluNAttribute(self.getRandNumberDType(a.dtype), 0)
1252
1253 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1254 return result_tens
1255
1256 def build_sigmoid(self, op, a):
1257 result_tens = OutputShaper.unaryOp(self.ser, a)
1258 self.ser.addOperator(op, [a.name], [result_tens.name])
1259 return result_tens
1260
1261 def build_tanh(self, op, a):
1262 result_tens = OutputShaper.unaryOp(self.ser, a)
1263 self.ser.addOperator(op, [a.name], [result_tens.name])
1264 return result_tens
1265
1266 def build_concat(self, op, a, b, axis):
1267 result_tens = OutputShaper.concatOp(self.ser, a, b, axis)
1268
1269 attr = ts.TosaSerializerAttribute()
1270 attr.AxisAttribute(axis)
1271
1272 self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
1273
1274 def build_pad(self, op, a, padding, qinfo):
1275 result_tens = OutputShaper.padOp(self.ser, a, padding)
1276
1277 # Need to turn the padding array into a TOSA tensor here.
1278 # This is one of the few tensor operands that does not get
1279 # randomly generated
Kevin Cheng550ccc52021-03-03 11:21:43 -08001280 padding_tens = self.ser.addConst(padding.shape, DType.INT32, padding)
Eric Kunzee5e26762020-10-13 16:11:07 -07001281
Kevin Cheng550ccc52021-03-03 11:21:43 -08001282 self.ser.addOperator(
1283 op, [a.name, padding_tens.name], [result_tens.name], None, qinfo
1284 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001285
1286 def build_reshape(self, op, a, newShape):
1287 result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)
1288
1289 attr = ts.TosaSerializerAttribute()
1290 attr.ReshapeAttribute(newShape)
1291
1292 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1293 return result_tens
1294
1295 def build_reverse(self, op, a, axis):
1296 result_tens = OutputShaper.unaryOp(self.ser, a)
1297
1298 attr = ts.TosaSerializerAttribute()
1299 attr.AxisAttribute(axis)
1300
1301 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1302 return result_tens
1303
1304 def build_transpose(self, op, a, perms):
1305 result_tens = OutputShaper.transposeOp(self.ser, a, perms)
1306
Kevin Cheng550ccc52021-03-03 11:21:43 -08001307 perms_tens = self.ser.addConst([len(perms)], DType.INT32, np.int32(perms))
Eric Kunzee5e26762020-10-13 16:11:07 -07001308
1309 self.ser.addOperator(op, [a.name, perms_tens.name], [result_tens.name])
1310 return result_tens
1311
1312 def build_slice(self, op, a, begin, size):
1313 result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)
1314
1315 attr = ts.TosaSerializerAttribute()
1316 attr.SliceAttribute(begin, size)
1317
1318 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1319 return result_tens
1320
1321 def build_tile(self, op, a, multiples):
1322 result_tens = OutputShaper.tileOp(self.ser, a, multiples)
1323
1324 attr = ts.TosaSerializerAttribute()
1325 attr.TileAttribute(multiples)
1326
1327 self.ser.addOperator(op, [a.name], [result_tens.name], attr)
1328 return result_tens
1329
    def build_gather(self, op, values):
        """Emit a GATHER with a freshly generated const indices tensor.

        The indices tensor is (N, W) with values in [0, K), so every lookup
        stays inside the K dimension of *values* (assumed (N, K, C) — the
        code only reads shape[0] and shape[1]; confirm layout against
        OutputShaper.gatherOp).
        """
        # Create a new indicies tensor
        # here with data that doesn't exceed the dimensions of the values tensor

        K = values.shape[1]  # K: range of valid index values
        # W is drawn from the configured tensor shape range.
        W = self.randInt(
            self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
        )  # W
        indicies_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values.shape[0], W])
        )  # (N, W)
        indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

        result_tens = OutputShaper.gatherOp(self.ser, values, indicies)

        self.ser.addOperator(op, [values.name, indicies.name], [result_tens.name])

        return result_tens
1349
    def build_scatter(self, op, values_in, input):
        """Emit a SCATTER with a freshly generated const indices tensor.

        Indices are (N, W) with values in [0, K); W is taken from the second
        dimension of *input* so every row of *input* has a destination.
        """
        # Create a new indicies tensor
        # here with data that doesn't exceed the dimensions of the values_in tensor

        K = values_in.shape[1]  # K: range of valid index values
        W = input.shape[1]  # W: number of updates
        indicies_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
        )  # (N, W)
        indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

        result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)

        self.ser.addOperator(
            op, [values_in.name, indicies.name, input.name], [result_tens.name]
        )

        return result_tens
1369
    def build_resize(
        self,
        op,
        input,
        mode,
        stride,
        offset,
        shift,
        stride_fp,
        offset_fp,
        output_dims,
        input_dtype,
        output_dtype,
    ):
        """Emit a RESIZE operator and return its output tensor.

        The integer (stride/offset/shift) and floating-point
        (stride_fp/offset_fp) parameter sets are both passed through; which
        set is meaningful depends on output_dtype (see the argument
        generator, which zeroes the unused set).
        """
        result_tens = OutputShaper.resizeOp(
            self.ser,
            input,
            mode,
            stride,
            offset,
            shift,
            stride_fp,
            offset_fp,
            output_dims,
            input_dtype,
            output_dtype,
        )

        attr = ts.TosaSerializerAttribute()

        attr.ResizeAttribute(
            output_dims, stride, offset, shift, stride_fp, offset_fp, mode
        )

        self.ser.addOperator(op, [input.name], [result_tens.name], attr)
        return result_tens
1406
1407 def build_identityn(self, op, val, val2):
1408
Kevin Cheng550ccc52021-03-03 11:21:43 -08001409 result_tens = OutputShaper.unaryOp(self.ser, val)
Eric Kunzee5e26762020-10-13 16:11:07 -07001410 result_tens2 = OutputShaper.unaryOp(self.ser, val2)
Kevin Cheng550ccc52021-03-03 11:21:43 -08001411 self.ser.addOperator(
1412 op, [val.name, val2.name], [result_tens.name, result_tens2.name]
1413 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001414 return result_tens
1415
    def build_placeholder(self, op, val):
        """Pass a placeholder tensor through an IDENTITY op.

        The identity avoids a reference-model warning about an unconsumed
        input.
        """
        # Add an identity op to avoid warning in the reference model
        return self.build_unary(Op.IDENTITY, val)
1419
1420 # Type Conversion
1421 def build_cast(self, op, val, out_dtype):
1422 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
1423 self.ser.addOperator(op, [val.name], [result_tens.name])
1424 return result_tens
1425
1426 def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
1427 result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
1428
1429 if per_channel:
1430 nc = val.shape[-1]
1431 else:
1432 nc = 1
1433
1434 in_type_width = self.typeWidth(val.dtype)
1435 out_type_width = self.typeWidth(out_dtype)
1436
Kevin Cheng3a478572021-01-22 17:21:02 -08001437 if val.dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001438 input_zp = self.randInt(-128, 128)
1439 in_type_width = in_type_width + 1
1440 elif val.dtype == DType.UINT8:
1441 input_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07001442 in_type_width = in_type_width + 1
1443 else:
1444 input_zp = 0
1445
Kevin Cheng3a478572021-01-22 17:21:02 -08001446 if out_dtype == DType.INT8:
Matthew Haddoncac4ee92021-07-22 14:30:53 +01001447 output_zp = self.randInt(-128, 128)
1448 out_type_width = out_type_width + 1
1449 elif out_dtype == DType.UINT8:
1450 output_zp = self.randInt(0, 256)
Eric Kunzee5e26762020-10-13 16:11:07 -07001451 out_type_width = out_type_width + 1
1452 else:
1453 output_zp = 0
1454
1455 # Calculate scale based on:
1456 # scale = a *(2^output_width)/(2^input_width))
1457
1458 a = np.float32(self.rng.random(size=[nc]))
1459 scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
1460
1461 if scale32:
1462 pass
1463 # Cap the scaling at 2^15 - 1 for scale16
1464 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
1465 else:
1466 # Cap the scaling at 2^15 - 1 for scale16
1467 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
1468
Kevin Cheng550ccc52021-03-03 11:21:43 -08001469 # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
Eric Kunzee5e26762020-10-13 16:11:07 -07001470
1471 multiplier_arr = np.int32(np.zeros(shape=[nc]))
1472 shift_arr = np.int32(np.zeros(shape=[nc]))
1473
1474 for i in range(nc):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001475 multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
1476 scale_arr[i], scale32
1477 )
Kevin Chengaee1fac2020-11-11 13:54:06 -08001478 if shift_arr[i] < 2 or shift_arr[i] > 62:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001479 self.ser.setExpectedFailure(True, "OpRescale: invalid shift value")
Eric Kunzee5e26762020-10-13 16:11:07 -07001480
Kevin Cheng550ccc52021-03-03 11:21:43 -08001481 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
Eric Kunzee5e26762020-10-13 16:11:07 -07001482
1483 attr = ts.TosaSerializerAttribute()
Kevin Cheng550ccc52021-03-03 11:21:43 -08001484 attr.RescaleAttribute(
1485 input_zp,
1486 output_zp,
1487 multiplier_arr,
1488 shift_arr,
1489 scale32,
1490 double_round,
1491 per_channel,
1492 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001493
1494 self.ser.addOperator(op, [val.name], [result_tens.name], attr)
1495 return result_tens
1496
    def build_cond_if_const(self, op, then_tens, else_tens, cond):
        """Emit a COND_IF whose then/else blocks each produce a const tensor.

        The supplied then/else tensors are used only for their shape; fresh
        random INT32 constants are generated for the block bodies.
        """
        # For cond_if with constants, we're supplied with then/else tensors that we ignore
        # (except for the generated shape) and the condition. Build Then/Else blocks
        # and fill them with const nodes for the body.

        # Condition tensor (rank-0 BOOL).
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        # Make then/else tensors
        out_shape = then_tens.shape
        then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
        else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))

        # And the result tensor based on any of the outputs
        result_tens = self.ser.addOutput(out_shape, DType.INT32)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(op, [cond_tens.name], [result_tens.name], attr)

        self.ser.startBasicBlock(then_block)
        # Build the actual then/else tensors inside their blocks
        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
        self.ser.addOutputTensor(then_tens)

        self.ser.startBasicBlock(else_block)
        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
        self.ser.addOutputTensor(else_tens)

        return result_tens
1532
    def build_cond_if_binary(self, op, a, b, cond):
        """Emit a COND_IF that computes a+b in the then block and a-b in else.

        Exercises blocks with inputs, in contrast to build_cond_if_const.
        """
        # For cond_if with a binary op in the then/else blocks, take a and b and
        # alternately add or subtract them based on the condition

        # Condition tensor (rank-0 BOOL).
        cond_tens = self.ser.addConst([], DType.BOOL, [cond])

        result_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.currBasicBlock.addOutput(result_tens.name)

        # Create the attribute with the names of the then/else blocks
        then_block = "THEN_BLOCK"
        else_block = "ELSE_BLOCK"
        attr = ts.TosaSerializerAttribute()
        attr.CondIfAttribute(then_block, else_block)

        # Finally, build the op and the two blocks
        self.ser.addOperator(
            op, [cond_tens.name, a.name, b.name], [result_tens.name], attr
        )

        # THEN block: out = a + b
        self.ser.startBasicBlock(then_block)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        then_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])

        # ELSE block: out = a - b
        self.ser.startBasicBlock(else_block)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(b)
        else_tens = self.ser.addOutput(a.shape, a.dtype)
        self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])

        return result_tens
1567
    def build_while_loop(self, op, a, iter_val):
        """Emit a WHILE_LOOP that accumulates *a* into an accumulator
        *iter_val* times, returning the accumulator output tensor.

        The condition block tests iter > 0; the body block adds *a* to the
        accumulator and decrements iter.
        """
        iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])

        cond_block = "COND_BLOCK"
        body_block = "BODY_BLOCK"

        attr = ts.TosaSerializerAttribute()
        attr.WhileLoopAttribute(cond_block, body_block)

        # Accumulator tensor, zero-initialized.
        # acc = self.ser.addOutput(a.shape, a.dtype)
        acc_init_val = np.int32(np.zeros(a.shape))
        acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val)

        # Intermediate/output tensors for everything going through the loop
        iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        a_out = self.ser.addIntermediate(a.shape, a.dtype)
        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)

        # While_loop operator
        self.ser.addOperator(
            op,
            [iter.name, a.name, acc.name],
            [iter_out.name, a_out.name, acc_out.name],
            attr,
        )

        # COND block (input: iter, output: cond_tens )
        self.ser.startBasicBlock(cond_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
        cond_tens = self.ser.addOutput([], DType.BOOL)
        self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])

        # BODY block (input: a, acc, iter, output: a, acc, iter)
        # Note that local intermediate tensors need to be declared here for the outputs
        self.ser.startBasicBlock(body_block)
        self.ser.addInputTensor(iter)
        self.ser.addInputTensor(a)
        self.ser.addInputTensor(acc)
        one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
        iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
        self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
        self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
        self.ser.addOutputTensor(iter_body_out)
        self.ser.addOutputTensor(a)
        self.ser.addOutputTensor(acc_body_out)

        return acc_out
1620
Kevin Cheng550ccc52021-03-03 11:21:43 -08001621 def genOpTestList(
1622 self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None
1623 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07001624
1625 try:
1626 op = self.TOSA_OP_LIST[opName]
1627 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001628 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07001629
1630 # Initialize a new random number generator
1631 self.rng = np.random.default_rng(self.random_seed)
1632
Kevin Cheng550ccc52021-03-03 11:21:43 -08001633 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001634
1635 # Generate the lists of arguments
Kevin Cheng550ccc52021-03-03 11:21:43 -08001636 rmin, rmax = op["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001637
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01001638 # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
1639 default_test_rank_range = range(1, 5)
1640
Eric Kunzee5e26762020-10-13 16:11:07 -07001641 # Test list consists of a tuple of:
1642 # (opName, testNameStr, dtype, shapeList, argumentsList)
1643 testList = []
1644
1645 if not shapeFilter:
1646 shapeFilter = [None]
1647
1648 for r in range(rmin, rmax + 1):
1649
1650 # Filter out the rank?
1651 if rankFilter is not None and r not in rankFilter:
1652 continue
Jeremy Johnson97eb75f2021-07-08 11:58:02 +01001653 if rankFilter is None and shapeFilter[0] is None and r not in default_test_rank_range:
1654 continue
Eric Kunzee5e26762020-10-13 16:11:07 -07001655
Kevin Cheng550ccc52021-03-03 11:21:43 -08001656 for t in op["types"]:
Eric Kunzee5e26762020-10-13 16:11:07 -07001657
1658 # Filter tests based on dtype?
1659 if dtypeFilter is not None:
Les Bell30e46802021-07-23 09:43:31 +01001660 if not (t in dtypeFilter or (isinstance(t, list) and t[0] in dtypeFilter)):
Eric Kunzee5e26762020-10-13 16:11:07 -07001661 continue
1662
1663 # Create the placeholder and const tensors
1664 for shape in shapeFilter:
1665 # A None shape chooses a random shape of a given rank
1666
1667 # Filter out by rank
1668 if shape is not None and len(shape) != r:
1669 continue
1670
1671 self.setTargetShape(shape)
1672 shapeList = tgen_fcn(self, op, r)
1673
1674 shapeStr = self.shapeStr(shapeList[0])
1675 typeStr = self.typeStr(t)
1676
1677 # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
1678 argList = []
1679 if agen_fcn:
1680 argList = agen_fcn(self, opName, shapeList, t)
1681 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001682 argList = [("", [])]
Eric Kunzee5e26762020-10-13 16:11:07 -07001683
1684 for argStr, args in argList:
1685 if argStr:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001686 testStr = "{}_{}_{}_{}".format(
1687 opName, shapeStr, typeStr, argStr
1688 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001689 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001690 testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
Eric Kunzee5e26762020-10-13 16:11:07 -07001691
1692 testList.append((opName, testStr, t, shapeList, args))
1693
1694 return testList
1695
Kevin Cheng989cb052021-04-28 16:29:44 -07001696 def serializeTest(self, opName, testStr, dtype_or_dtypeList, shapeList, testArgs):
Eric Kunzee5e26762020-10-13 16:11:07 -07001697 try:
1698 op = self.TOSA_OP_LIST[opName]
1699 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001700 raise Exception("Cannot find op with name {}".format(opName))
Eric Kunzee5e26762020-10-13 16:11:07 -07001701
1702 # Create a serializer
1703 self.createSerializer(opName, testStr)
1704
Kevin Cheng550ccc52021-03-03 11:21:43 -08001705 build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
1706 pCount, cCount = op["operands"]
Kevin Cheng989cb052021-04-28 16:29:44 -07001707 num_operands = pCount + cCount
1708
1709 if isinstance(dtype_or_dtypeList, list):
1710 dtypeList = dtype_or_dtypeList
1711 else:
1712 dtypeList = [dtype_or_dtypeList] * (num_operands)
1713
1714 assert (
1715 len(shapeList) == num_operands
1716 ), "shapeList length {} must match number of operands {}".format(
1717 len(shapeList), num_operands
1718 )
1719 assert (
1720 len(dtypeList) == num_operands
1721 ), "dtypeList length {} must match number of operands {}".format(
1722 len(dtypeList), num_operands
1723 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001724
1725 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001726 qgen = op["qgen"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001727 except KeyError:
1728 qgen = None
1729
1730 # Build the random tensor operands and the test
1731 tens = []
Kevin Chengaee1fac2020-11-11 13:54:06 -08001732
1733 # If test is ArithmeticRightShift, force value of operand[1] to be within [0, num_bits]
Kevin Cheng550ccc52021-03-03 11:21:43 -08001734 if op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
1735 assert (
1736 pCount == 2 and cCount == 0
1737 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
Kevin Chengaee1fac2020-11-11 13:54:06 -08001738
1739 placeholders = []
1740 for idx, shape in enumerate(shapeList[:]):
1741 if idx == 1:
Kevin Cheng989cb052021-04-28 16:29:44 -07001742 if dtypeList[idx] == DType.INT8:
Kevin Chengaee1fac2020-11-11 13:54:06 -08001743 arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07001744 elif dtypeList[idx] == DType.INT16:
Kevin Chengaee1fac2020-11-11 13:54:06 -08001745 arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
Kevin Cheng989cb052021-04-28 16:29:44 -07001746 elif dtypeList[idx] == DType.INT32:
Kevin Chengaee1fac2020-11-11 13:54:06 -08001747 arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
1748 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001749 raise Exception("OpArithmeticRightShift: invalid input dtype")
Kevin Chengaee1fac2020-11-11 13:54:06 -08001750 else:
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001751 arr = self.getRandTensor(shape, dtypeList[idx])
Kevin Cheng989cb052021-04-28 16:29:44 -07001752 placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
Kevin Chengaee1fac2020-11-11 13:54:06 -08001753
1754 tens.extend(placeholders)
Matthew Haddona44ac5e2021-07-27 16:31:16 +01001755 elif op["op"] == Op.SELECT:
1756 # Set datatype of condition tensor to boolean
1757 dtypeList[0] = DType.BOOL
1758 tens.extend(
1759 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
1760 )
1761 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001762 elif op["op"] == Op.DIV:
1763 assert (
1764 pCount == 2 and cCount == 0
1765 ), "Op.Div must have 2 placeholders, 0 consts"
1766
1767 placeholders = []
1768
1769 # Two invalid cases for Op.DIV:
1770 # 1. divisor == 0
Kevin Cheng47315e12021-05-13 17:41:28 -07001771 # 2. dividend == -(1<<31) and divisor == -1
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001772 while True:
1773 dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
1774 divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])
1775
1776 if (divisor_arr == 0).any():
1777 continue
1778
Kevin Cheng47315e12021-05-13 17:41:28 -07001779 if (dividend_arr == -(2 ** 31)).any() and (divisor_arr == -1).any():
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07001780 continue
1781
1782 break
1783
1784 placeholders.append(
1785 self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1786 )
1787 placeholders.append(
1788 self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1789 )
1790
1791 tens.extend(placeholders)
1792 elif op["op"] == Op.MUL:
1793 assert (
1794 pCount == 2 and cCount == 0
1795 ), "Op.MUL must have 2 placeholders, 0 consts"
1796
1797 if dtypeList[0] == DType.FLOAT:
1798 tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
1799 else:
1800 placeholders = []
1801
1802 # Make sure multiply result in int32 range
1803 shift = testArgs[0]
1804 if dtypeList[0] == DType.INT8:
1805 num_bits = 8
1806 elif dtypeList[0] == DType.INT16:
1807 num_bits = 16
1808 elif dtypeList[0] == DType.INT32:
1809 num_bits = 32
1810 else:
1811 raise Exception("OpMul: invalid input dtype")
1812
1813 for idx, shape in enumerate(shapeList[:]):
1814 low = -(2 ** (num_bits - 1))
1815 high = (2 ** (num_bits - 1)) - 1
1816
1817 a_arr = np.int32(
1818 self.rng.integers(low=low, high=high, size=shapeList[0])
1819 )
1820 b_arr = np.int32(
1821 self.rng.integers(low=low, high=high, size=shapeList[1])
1822 )
1823
1824 i = 0
1825 while True:
1826
1827 a_arr_64 = a_arr.astype(np.int64)
1828 b_arr_64 = b_arr.astype(np.int64)
1829
1830 if shift > 0:
1831 rounding = 1 << (shift - 1)
1832 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
1833 else:
1834 result_arr = a_arr_64 * b_arr_64
1835
1836 if (result_arr > -(2 ** 31)).all() and (
1837 result_arr <= ((2 ** 31) - 1)
1838 ).all():
1839 break
1840
1841 i = i + 1
1842 a_arr = a_arr // 2
1843 b_arr = b_arr // 2
1844
1845 placeholders.append(
1846 self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1847 )
1848 placeholders.append(
1849 self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1850 )
1851
1852 tens.extend(placeholders)
Kevin Chengaee1fac2020-11-11 13:54:06 -08001853 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07001854 tens.extend(
1855 self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
1856 )
1857 tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
Eric Kunzee5e26762020-10-13 16:11:07 -07001858
1859 if qgen is not None:
Les Bell30e46802021-07-23 09:43:31 +01001860 qinfo = qgen(self, op, dtype_or_dtypeList)
Eric Kunzee5e26762020-10-13 16:11:07 -07001861 else:
1862 qinfo = None
1863
1864 try:
1865 if qinfo is not None:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001866 resultName = build_fcn(self, op["op"], *tens, *testArgs, qinfo)
Eric Kunzee5e26762020-10-13 16:11:07 -07001867 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001868 resultName = build_fcn(self, op["op"], *tens, *testArgs)
Eric Kunzee5e26762020-10-13 16:11:07 -07001869 except TypeError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001870 print(
1871 "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
1872 build_fcn, tens, testArgs
1873 )
1874 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001875 raise e
1876
1877 # Save the serialized test
Kevin Cheng550ccc52021-03-03 11:21:43 -08001878 self.serialize("test")
Eric Kunzee5e26762020-10-13 16:11:07 -07001879
1880 def createDynamicOpLists(self):
1881
1882 # Dynamically create op lists for convolutions with a list of kernel sizes
Kevin Cheng550ccc52021-03-03 11:21:43 -08001883 KERNELS = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
Eric Kunzee5e26762020-10-13 16:11:07 -07001884
1885 for k in KERNELS:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001886 testName = "conv2d_{}x{}".format(k[0], k[1])
1887 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
1888 self.TOSA_OP_LIST[testName]["filter"] = k
1889 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001890
Kevin Cheng550ccc52021-03-03 11:21:43 -08001891 testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
1892 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
1893 "depthwise_conv2d_TEMPLATE"
1894 ].copy()
1895 self.TOSA_OP_LIST[testName]["filter"] = k
1896 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001897
Kevin Cheng550ccc52021-03-03 11:21:43 -08001898 testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
1899 self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
1900 "transpose_conv2d_TEMPLATE"
1901 ].copy()
1902 self.TOSA_OP_LIST[testName]["filter"] = k
1903 self.TOSA_OP_LIST[testName]["template"] = False
Eric Kunzee5e26762020-10-13 16:11:07 -07001904
1905 # Delete any templates after having created any dynamic ops
1906 # This is a two-pass operation because it's bad practice to delete
1907 # keys from dictionaries while iterating
1908 keyList = []
1909 for k in self.TOSA_OP_LIST:
1910 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001911 if self.TOSA_OP_LIST[k]["template"] == True:
Eric Kunzee5e26762020-10-13 16:11:07 -07001912 keyList.append(k)
1913 continue
1914 except KeyError:
1915 pass
1916
1917 for k in keyList:
1918 del self.TOSA_OP_LIST[k]
1919
1920 def initOpListDefaults(self):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001921 """Fill in default fields for ops if they aren't already specified.
1922 Look for missing required fields (datastructure linting)."""
Eric Kunzee5e26762020-10-13 16:11:07 -07001923 for op in self.TOSA_OP_LIST:
1924
1925 # Required fields
1926 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001927 pl, c = self.TOSA_OP_LIST[op]["operands"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001928 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001929 raise Exception(
1930 "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op)
1931 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001932
1933 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001934 fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001935 except (KeyError, ValueError, TypeError):
Kevin Cheng550ccc52021-03-03 11:21:43 -08001936 raise Exception(
1937 "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
1938 op
1939 )
1940 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001941
1942 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001943 types = self.TOSA_OP_LIST[op]["types"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001944 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001945 raise Exception(
1946 "Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
1947 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001948
1949 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001950 opcode = self.TOSA_OP_LIST[op]["op"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001951 except KeyError as e:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001952 raise Exception(
1953 "Op {} is missing the Op field in TOSA_OP_LIST".format(op)
1954 )
Eric Kunzee5e26762020-10-13 16:11:07 -07001955
1956 # Put in default rank range, if missing
1957 try:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001958 rank = self.TOSA_OP_LIST[op]["rank"]
Eric Kunzee5e26762020-10-13 16:11:07 -07001959 except KeyError:
Kevin Cheng550ccc52021-03-03 11:21:43 -08001960 self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
Eric Kunzee5e26762020-10-13 16:11:07 -07001961
    # Tensor operator list (TOSA_OP_LIST below); each entry supports:
    # 'op': op name
    # 'operands': tuple of (placeholder, const) operands
    # 'rank': optional, restricts rank to tuple inclusive of (min, max),
    # if not specified, defaults to (1, 4)
    # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
    # 'types': array of datatypes to be tested
    TYPE_FP = [DType.FLOAT]

    # Integer-only type lists.  INT4 is deliberately excluded — it only
    # appears in TYPE_CONV2D below.
    TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
    TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4

    TYPE_BOOL = [DType.BOOL]
    # Float plus 32-bit int
    TYPE_FI32 = [DType.FLOAT, DType.INT32]
    # Float, narrow/standard ints and bool
    TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
    # Float plus 16-bit int
    TYPE_FI16 = [DType.FLOAT, DType.INT16]

    # "Narrow" integer types (no INT32) plus float
    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]

    # Per-operand dtype combinations for the convolution-style ops; each
    # list entry is presumably (input, weight, accumulator) — confirm
    # against TosaTensorGen.tgConv2D.  The bare FLOAT entry applies float
    # to every operand.
    TYPE_CONV2D = [
        [DType.INT8, DType.INT4, DType.INT32],
        [DType.INT8, DType.INT8, DType.INT32],
        [DType.INT16, DType.INT8, DType.INT48],
        DType.FLOAT,
    ]

    # Default inclusive (min, max) rank range used by initOpListDefaults
    # for ops that do not declare 'rank'
    DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
Eric Kunzee5e26762020-10-13 16:11:07 -07001989
1990 TOSA_OP_LIST = {
Jared Smolens573ecd42021-03-04 15:24:10 -08001991 # Tensor operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08001992 "argmax": {
1993 "op": Op.ARGMAX,
1994 "operands": (1, 0),
1995 "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
1996 "types": TYPE_NARROW_INT_FP,
1997 },
Jared Smolens573ecd42021-03-04 15:24:10 -08001998 "avg_pool2d": {
1999 "op": Op.AVG_POOL2D,
2000 "operands": (1, 0),
2001 "rank": (4, 4),
2002 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
2003 "qgen": TosaQuantGen.qgUnary,
2004 "types": TYPE_NARROW_INT_FP,
2005 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002006 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002007 "conv2d_TEMPLATE": {
2008 "op": Op.CONV2D,
2009 "operands": (1, 2),
2010 "rank": (4, 4),
2011 "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D),
2012 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07002013 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002014 "template": True,
2015 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002016 # Conv3d TBD
Eric Kunzee5e26762020-10-13 16:11:07 -07002017 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002018 "depthwise_conv2d_TEMPLATE": {
2019 "op": Op.DEPTHWISE_CONV2D,
2020 "operands": (1, 2),
2021 "filter": [1, 1],
2022 "rank": (4, 4),
2023 "build_fcn": (
2024 build_depthwise_conv2d,
2025 TosaTensorGen.tgDepthwiseConv2D,
2026 TosaArgGen.agConv2D,
2027 ),
2028 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07002029 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002030 "template": True,
2031 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002032 "fully_connected": {
2033 "op": Op.FULLY_CONNECTED,
2034 "operands": (1, 2),
2035 "rank": (2, 2),
2036 "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
2037 "qgen": TosaQuantGen.qgConv,
2038 "types": TYPE_CONV2D,
2039 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002040 "matmul": {
2041 "op": Op.MATMUL,
2042 "operands": (2, 0),
Kevin Cheng2d60f002021-06-09 14:18:32 -07002043 "rank": (3, 3),
Jared Smolens573ecd42021-03-04 15:24:10 -08002044 "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
2045 "qgen": TosaQuantGen.qgMatmul,
2046 "types": TYPE_NARROW_INT_FP,
2047 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002048 "max_pool2d": {
2049 "op": Op.MAX_POOL2D,
2050 "operands": (1, 0),
2051 "rank": (4, 4),
2052 "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
2053 "types": TYPE_NARROW_INT_FP,
2054 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002055 # Templated operator. Filled in by createDynamicOpLists
Kevin Cheng550ccc52021-03-03 11:21:43 -08002056 "transpose_conv2d_TEMPLATE": {
2057 "op": Op.TRANSPOSE_CONV2D,
Kevin Cheng989cb052021-04-28 16:29:44 -07002058 "operands": (1, 2),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002059 "rank": (4, 4),
2060 "build_fcn": (
2061 build_transpose_conv2d,
2062 TosaTensorGen.tgTransposeConv2D,
2063 TosaArgGen.agTransposeConv2D,
2064 ),
2065 "qgen": TosaQuantGen.qgConv,
Kevin Cheng989cb052021-04-28 16:29:44 -07002066 "types": TYPE_CONV2D,
Kevin Cheng550ccc52021-03-03 11:21:43 -08002067 "template": True,
2068 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002069 # Activation functions
Kevin Cheng550ccc52021-03-03 11:21:43 -08002070 "clamp": {
2071 "op": Op.CLAMP,
2072 "operands": (1, 0),
2073 "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
2074 "types": TYPE_NARROW_INT_FP,
2075 },
2076 "relun": {
2077 "op": Op.RELUN,
2078 "operands": (1, 0),
2079 "build_fcn": (build_relun, TosaTensorGen.tgBasic, None),
2080 "types": TYPE_FI32,
2081 },
2082 "sigmoid": {
2083 "op": Op.SIGMOID,
2084 "operands": (1, 0),
2085 "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
2086 "types": TYPE_FP,
2087 },
2088 "tanh": {
2089 "op": Op.TANH,
2090 "operands": (1, 0),
2091 "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
2092 "types": TYPE_FP,
2093 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002094 # Elementwise Binary Operators
2095 "add": {
2096 "op": Op.ADD,
2097 "operands": (2, 0),
2098 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2099 "types": TYPE_FI32,
2100 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002101 "arithmetic_right_shift": {
2102 "op": Op.ARITHMETIC_RIGHT_SHIFT,
2103 "operands": (2, 0),
2104 "build_fcn": (
2105 build_arithmetic_right_shift,
2106 TosaTensorGen.tgBroadcastFuzz,
2107 TosaArgGen.agArithmeticRightShift,
2108 ),
2109 "types": TYPE_INT,
2110 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002111 "bitwise_and": {
2112 "op": Op.BITWISE_AND,
2113 "operands": (2, 0),
2114 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2115 "types": TYPE_INT,
2116 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002117 "bitwise_or": {
2118 "op": Op.BITWISE_OR,
2119 "operands": (2, 0),
2120 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2121 "types": TYPE_INT,
2122 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002123 "bitwise_xor": {
2124 "op": Op.BITWISE_XOR,
2125 "operands": (2, 0),
2126 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2127 "types": TYPE_INT,
2128 },
Kevin Cheng14d7f7a2021-05-12 10:44:49 -07002129 "div": {
2130 "op": Op.DIV,
2131 "operands": (2, 0),
2132 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2133 "types": [DType.INT32],
2134 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002135 "logical_and": {
2136 "op": Op.LOGICAL_AND,
2137 "operands": (2, 0),
2138 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2139 "types": TYPE_BOOL,
2140 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002141 "logical_left_shift": {
2142 "op": Op.LOGICAL_LEFT_SHIFT,
2143 "operands": (2, 0),
2144 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2145 "types": TYPE_INT,
2146 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002147 "logical_right_shift": {
2148 "op": Op.LOGICAL_RIGHT_SHIFT,
2149 "operands": (2, 0),
2150 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2151 "types": TYPE_INT,
2152 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002153 "logical_or": {
2154 "op": Op.LOGICAL_OR,
2155 "operands": (2, 0),
2156 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2157 "types": TYPE_BOOL,
2158 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002159 "logical_xor": {
2160 "op": Op.LOGICAL_XOR,
2161 "operands": (2, 0),
2162 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2163 "types": TYPE_BOOL,
2164 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002165 "maximum": {
2166 "op": Op.MAXIMUM,
2167 "operands": (2, 0),
2168 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2169 "types": TYPE_FI32,
2170 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002171 "minimum": {
2172 "op": Op.MINIMUM,
2173 "operands": (2, 0),
2174 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2175 "types": TYPE_FI32,
2176 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002177 "mul": {
2178 "op": Op.MUL,
2179 "operands": (2, 0),
2180 "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
2181 "types": TYPE_INT_FP,
2182 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002183 "pow": {
2184 "op": Op.POW,
2185 "operands": (2, 0),
2186 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
2187 "types": TYPE_FP,
2188 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002189 "sub": {
2190 "op": Op.SUB,
2191 "operands": (2, 0),
2192 "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
2193 "types": TYPE_FI32,
2194 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002195 "table": {
2196 "op": Op.TABLE,
2197 # Use the automatic generation functions to create the input array
2198 # but create the table tensor in the build function, as it may be
2199 # a different type from the input
2200 "operands": (1, 0),
2201 "build_fcn": (build_table, TosaTensorGen.tgBasic, None),
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002202 "types": [DType.INT8, DType.INT16],
Jared Smolens573ecd42021-03-04 15:24:10 -08002203 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002204 # Elementwise Unary operators
2205 "abs": {
2206 "op": Op.ABS,
2207 "operands": (1, 0),
2208 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2209 "types": TYPE_FI32,
2210 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002211 "bitwise_not": {
2212 "op": Op.BITWISE_NOT,
2213 "operands": (1, 0),
2214 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2215 "types": TYPE_INT,
2216 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002217 "ceil": {
2218 "op": Op.CEIL,
2219 "operands": (1, 0),
2220 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2221 "types": TYPE_FP,
2222 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002223 "clz": {
2224 "op": Op.CLZ,
2225 "operands": (1, 0),
2226 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2227 "types": [DType.INT32],
2228 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002229 "exp": {
2230 "op": Op.EXP,
2231 "operands": (1, 0),
2232 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2233 "types": TYPE_FP,
2234 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002235 "floor": {
2236 "op": Op.FLOOR,
2237 "operands": (1, 0),
2238 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2239 "types": TYPE_FP,
2240 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002241 "log": {
2242 "op": Op.LOG,
2243 "operands": (1, 0),
2244 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2245 "types": TYPE_FP,
2246 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002247 "logical_not": {
2248 "op": Op.LOGICAL_NOT,
2249 "operands": (1, 0),
2250 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2251 "types": TYPE_BOOL,
2252 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002253 "negate": {
2254 "op": Op.NEGATE,
2255 "operands": (1, 0),
2256 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2257 "qgen": TosaQuantGen.qgUnary,
2258 "types": TYPE_INT_FP,
2259 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002260 "reciprocal": {
2261 "op": Op.RECIPROCAL,
2262 "operands": (1, 0),
2263 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2264 "types": TYPE_FP,
2265 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002266 "rsqrt": {
2267 "op": Op.RSQRT,
2268 "operands": (1, 0),
2269 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2270 "types": TYPE_FP,
2271 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002272 # Elementwise Ternary operators
2273 "select": {
2274 "op": Op.SELECT,
2275 "operands": (3, 0),
2276 "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
2277 "types": TYPE_FIB,
2278 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002279 # Comparison operators
2280 "equal": {
2281 "op": Op.EQUAL,
2282 "operands": (2, 0),
2283 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2284 "types": TYPE_FI32,
2285 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002286 "greater_equal": {
2287 "op": Op.GREATER_EQUAL,
2288 "operands": (2, 0),
2289 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2290 "types": TYPE_FI32,
2291 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002292 "greater": {
2293 "op": Op.GREATER,
2294 "operands": (2, 0),
2295 "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
2296 "types": TYPE_FI32,
2297 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002298 # Reduction operators
2299 "reduce_all": {
2300 "op": Op.REDUCE_ALL,
2301 "operands": (1, 0),
2302 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2303 "types": TYPE_BOOL,
2304 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002305 "reduce_any": {
2306 "op": Op.REDUCE_ANY,
2307 "operands": (1, 0),
2308 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2309 "types": TYPE_BOOL,
2310 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002311 "reduce_max": {
2312 "op": Op.REDUCE_MAX,
2313 "operands": (1, 0),
2314 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2315 "types": TYPE_INT_FP,
2316 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002317 "reduce_min": {
2318 "op": Op.REDUCE_MAX,
2319 "operands": (1, 0),
2320 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2321 "types": TYPE_INT_FP,
2322 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002323 "reduce_product": {
2324 "op": Op.REDUCE_PRODUCT,
2325 "operands": (1, 0),
2326 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2327 "types": TYPE_FP,
2328 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002329 "reduce_sum": {
2330 "op": Op.REDUCE_SUM,
2331 "operands": (1, 0),
2332 "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2333 "types": TYPE_FI32,
2334 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002335 # Data layout operators
Kevin Cheng550ccc52021-03-03 11:21:43 -08002336 "concat": {
2337 "op": Op.CONCAT,
2338 "operands": (2, 0),
2339 "build_fcn": (build_concat, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2340 "types": TYPE_FIB,
2341 },
2342 "pad": {
2343 "op": Op.PAD,
2344 "operands": (1, 0),
2345 "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
2346 "qgen": TosaQuantGen.qgPad,
2347 "types": TYPE_FIB,
2348 },
2349 "reshape": {
2350 "op": Op.RESHAPE,
2351 "operands": (1, 0),
2352 "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
2353 "types": TYPE_FIB,
2354 },
2355 "reverse": {
2356 "op": Op.REVERSE,
2357 "operands": (1, 0),
2358 "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
2359 "types": TYPE_FIB,
2360 },
2361 "slice": {
2362 "op": Op.SLICE,
2363 "operands": (1, 0),
2364 "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
2365 "types": TYPE_FIB,
2366 },
2367 "tile": {
2368 "op": Op.TILE,
2369 "operands": (1, 0),
2370 "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
2371 "types": TYPE_FIB,
2372 },
2373 "transpose": {
2374 "op": Op.TRANSPOSE,
2375 "operands": (1, 0),
Jeremy Johnsona6185572021-06-21 15:55:35 +01002376 "rank": (1, 4),
Kevin Cheng550ccc52021-03-03 11:21:43 -08002377 "build_fcn": (
2378 build_transpose,
2379 TosaTensorGen.tgBasic,
2380 TosaArgGen.agTranspose,
2381 ),
2382 "types": TYPE_FIB,
2383 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002384 # Data nodes
2385 "const": {
2386 "op": Op.CONST,
2387 "operands": (1, 0),
2388 "build_fcn": (build_placeholder, TosaTensorGen.tgBasic, None),
2389 "types": TYPE_FIB,
2390 },
Jared Smolens573ecd42021-03-04 15:24:10 -08002391 "identity": {
2392 "op": Op.IDENTITY,
2393 "operands": (1, 0),
2394 "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
2395 "types": TYPE_FIB,
2396 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002397 # Scatter/Gather
Kevin Cheng550ccc52021-03-03 11:21:43 -08002398 "gather": {
2399 "op": Op.GATHER,
2400 # Only specify 'values' tensor here. 'indices' is generated in op building stage
2401 "operands": (1, 0),
2402 "rank": (3, 3),
2403 "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
2404 "types": TYPE_INT_FP,
2405 },
2406 "scatter": {
2407 "op": Op.SCATTER,
2408 # Only specify 'values_in' tensor here.
2409 #'indices' and 'input' are generated in op building stage
2410 "operands": (2, 0),
2411 "rank": (3, 3),
2412 "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
2413 "types": TYPE_INT_FP,
2414 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002415 # Image operations
Kevin Cheng550ccc52021-03-03 11:21:43 -08002416 "resize": {
2417 "op": Op.RESIZE,
2418 "operands": (1, 0),
2419 "rank": (4, 4),
2420 "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
2421 "types": [DType.INT8, DType.INT16, DType.FLOAT],
2422 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002423 # Type conversion
Kevin Cheng550ccc52021-03-03 11:21:43 -08002424 "cast": {
2425 "op": Op.CAST,
2426 "operands": (1, 0),
2427 "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
2428 "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
2429 },
2430 "rescale": {
2431 "op": Op.RESCALE,
2432 "operands": (1, 0),
2433 "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
Matthew Haddoncac4ee92021-07-22 14:30:53 +01002434 "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
Kevin Cheng550ccc52021-03-03 11:21:43 -08002435 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002436 # Custom
2437 # Not implemented.
Jared Smolens573ecd42021-03-04 15:24:10 -08002438 # Control flow operators
Eric Kunzee5e26762020-10-13 16:11:07 -07002439 # Two varients of cond_if, one that generates one of two constant tensors (no
2440 # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
2441 # (two inputs to the basic blocks, one output)
Kevin Cheng550ccc52021-03-03 11:21:43 -08002442 "cond_if_const": {
2443 "op": Op.COND_IF,
2444 "operands": (0, 2),
2445 "build_fcn": (
2446 build_cond_if_const,
2447 TosaTensorGen.tgBasic,
2448 TosaArgGen.agCondIf,
2449 ),
2450 "types": [DType.BOOL],
2451 },
2452 "cond_if_binary": {
2453 "op": Op.COND_IF,
2454 "operands": (2, 0),
2455 "build_fcn": (
2456 build_cond_if_binary,
2457 TosaTensorGen.tgBasic,
2458 TosaArgGen.agCondIf,
2459 ),
2460 "types": TYPE_FI32,
2461 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002462 # while_loop
Kevin Cheng550ccc52021-03-03 11:21:43 -08002463 "while_loop": {
2464 "op": Op.WHILE_LOOP,
2465 "operands": (0, 1),
2466 "build_fcn": (
2467 build_while_loop,
2468 TosaTensorGen.tgBasic,
2469 TosaArgGen.agWhileLoop,
2470 ),
2471 "types": [DType.INT32],
2472 },
Eric Kunzee5e26762020-10-13 16:11:07 -07002473 }
2474
Kevin Cheng550ccc52021-03-03 11:21:43 -08002475
class OutputShaper:
    """Compute the expected output shape and datatype for common classes
    of TOSA operations.

    All functionality is exposed through static methods; each takes the
    serializer plus the op's input tensors/attributes, and registers and
    returns the output tensor via ``ser.addOutput()``.
    """

    # Methods in this class compute the expected output shape and datatype
    # for common classes of operations
    def __init__(self):
        # Stateless: every helper is a @staticmethod, so there is nothing
        # to initialize.
        pass

    # These methods return arguments that can be used for
    # creating a new output tensor
2484 @staticmethod
2485 def binaryBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002486 assert len(a.shape) == len(b.shape)
2487 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002488
2489 shape = []
2490 for i in range(len(a.shape)):
2491 if a.shape[i] == 1:
2492 shape.append(b.shape[i])
2493 else:
2494 shape.append(a.shape[i])
2495
Kevin Cheng550ccc52021-03-03 11:21:43 -08002496 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002497
2498 @staticmethod
2499 def binaryNonBroadcastOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002500 assert len(a.shape) == len(b.shape)
2501 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002502
2503 shape = []
2504 for i in range(len(a.shape)):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002505 assert a.shape[i] == b.shape[i]
Eric Kunzee5e26762020-10-13 16:11:07 -07002506 shape.append(a.shape[i])
2507
Kevin Cheng550ccc52021-03-03 11:21:43 -08002508 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002509
2510 @staticmethod
2511 def unaryOp(ser, a):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002512 return ser.addOutput(a.shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002513
2514 @staticmethod
2515 def selectOp(ser, cond, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002516 assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
2517 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002518
2519 shape = []
2520 for i in range(len(a.shape)):
2521 shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
2522
Kevin Cheng550ccc52021-03-03 11:21:43 -08002523 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002524
2525 @staticmethod
2526 def binaryComparisonOp(ser, a, b):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002527 assert len(a.shape) == len(b.shape)
2528 assert a.dtype == b.dtype
Eric Kunzee5e26762020-10-13 16:11:07 -07002529
2530 # Do broadcast
2531 shape = []
2532 for i in range(len(a.shape)):
2533 if a.shape[i] == 1:
2534 shape.append(b.shape[i])
2535 else:
2536 shape.append(a.shape[i])
2537
2538 # Force the output type to bool
Kevin Cheng550ccc52021-03-03 11:21:43 -08002539 return ser.addOutput(shape, DType.BOOL)
Eric Kunzee5e26762020-10-13 16:11:07 -07002540
2541 @staticmethod
2542 def reduceOp(ser, a, axis):
2543
2544 shape = a.shape.copy()
2545
2546 shape[axis] = 1
2547
Kevin Cheng550ccc52021-03-03 11:21:43 -08002548 return ser.addOutput(shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002549
2550 @staticmethod
2551 def argmaxOp(ser, a, axis):
2552 shape = a.shape.copy()
2553 del shape[axis]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002554 return ser.addOutput(shape, DType.INT32)
Eric Kunzee5e26762020-10-13 16:11:07 -07002555
2556 @staticmethod
2557 def conv2dOp(ser, ifm, filter, strides, padding, dilations):
2558
2559 # IFM: NHWC
2560 # Filter: OHWI
2561 # OFM: NHWC
2562
2563 if len(padding) == 2:
2564 # Expand padding to 4 parameters in the case of transpose_conv2d
2565 # From H,W to T,B,L,R
2566 padding = [padding[0], padding[0], padding[1], padding[1]]
2567
Kevin Cheng550ccc52021-03-03 11:21:43 -08002568 h = (
2569 ifm.shape[1]
2570 - filter.shape[1]
2571 - (filter.shape[1] - 1) * (dilations[0] - 1)
2572 + padding[0]
2573 + padding[1]
2574 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002575
Kevin Cheng550ccc52021-03-03 11:21:43 -08002576 w = (
2577 ifm.shape[2]
2578 - filter.shape[2]
2579 - (filter.shape[2] - 1) * (dilations[1] - 1)
2580 + padding[2]
2581 + padding[3]
2582 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002583
2584 if h <= 0 or w <= 0:
2585 # Invalid test parameters?
2586 h = 0
2587 w = 0
Kevin Cheng550ccc52021-03-03 11:21:43 -08002588 ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002589
2590 ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
2591
Kevin Cheng3a478572021-01-22 17:21:02 -08002592 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002593 out_dtype = DType.INT32
2594 elif ifm.dtype == DType.INT16:
2595 out_dtype = DType.INT48
2596 elif ifm.dtype == DType.FLOAT:
2597 out_dtype = DType.FLOAT
2598 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002599 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002600
Kevin Cheng550ccc52021-03-03 11:21:43 -08002601 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002602
2603 @staticmethod
2604 def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
2605 # IFM: NHWC
2606 # Filter: HWCM
2607 # OFM: NHW C*M
Kevin Cheng550ccc52021-03-03 11:21:43 -08002608 h = (
2609 ifm.shape[1]
2610 - filter.shape[0]
2611 - (filter.shape[0] - 1) * (dilations[0] - 1)
2612 + padding[0]
2613 + padding[1]
2614 ) // strides[0] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002615
Kevin Cheng550ccc52021-03-03 11:21:43 -08002616 w = (
2617 ifm.shape[2]
2618 - filter.shape[1]
2619 - (filter.shape[1] - 1) * (dilations[1] - 1)
2620 + padding[2]
2621 + padding[3]
2622 ) // strides[1] + 1
Eric Kunzee5e26762020-10-13 16:11:07 -07002623
2624 if h <= 0 or w <= 0:
2625 # Invalid test parameters?
2626 h = 0
2627 w = 0
Kevin Cheng550ccc52021-03-03 11:21:43 -08002628 ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002629
2630 ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
2631
Kevin Cheng3a478572021-01-22 17:21:02 -08002632 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002633 out_dtype = DType.INT32
2634 elif ifm.dtype == DType.INT16:
2635 out_dtype = DType.INT48
2636 elif ifm.dtype == DType.FLOAT:
2637 out_dtype = DType.FLOAT
2638 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002639 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002640
Kevin Cheng550ccc52021-03-03 11:21:43 -08002641 return ser.addOutput(ofm_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002642
2643 @staticmethod
2644 def pool2dOp(ser, ifm, kernel, stride, pad):
2645 # input: NHWC
2646 h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
2647 w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
2648
2649 if h <= 0 or w <= 0:
2650 # Invalid test parameters?
2651 h = 0
2652 w = 0
Kevin Cheng550ccc52021-03-03 11:21:43 -08002653 ser.setExpectedFailure(True, "Invalid combination of pooling parameters")
Eric Kunzee5e26762020-10-13 16:11:07 -07002654
2655 ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
Kevin Cheng550ccc52021-03-03 11:21:43 -08002656 return ser.addOutput(ofm_shape, ifm.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002657
2658 @staticmethod
2659 def fullyConnectedOp(ser, input, filter):
2660 # input: N, IC
2661 # filter: OC, IC
2662 # output: N, OC
2663
2664 output_shape = [input.shape[0], filter.shape[0]]
2665
Kevin Cheng3a478572021-01-22 17:21:02 -08002666 if input.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002667 out_dtype = DType.INT32
2668 elif input.dtype == DType.INT16:
2669 out_dtype = DType.INT48
2670 elif input.dtype == DType.FLOAT:
2671 out_dtype = DType.FLOAT
2672 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002673 raise Exception("Unsupported input dtype: {}".format(input.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002674
Kevin Cheng550ccc52021-03-03 11:21:43 -08002675 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002676
2677 @staticmethod
2678 def matmulOp(ser, a, b):
Kevin Cheng2d60f002021-06-09 14:18:32 -07002679 # a: N, H, C
2680 # b: N, C, W
2681 # out: N, H, W
Eric Kunzee5e26762020-10-13 16:11:07 -07002682
Kevin Cheng2d60f002021-06-09 14:18:32 -07002683 output_shape = [a.shape[0], a.shape[1], b.shape[2]]
Eric Kunzee5e26762020-10-13 16:11:07 -07002684
Kevin Cheng3a478572021-01-22 17:21:02 -08002685 if a.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002686 out_dtype = DType.INT32
2687 elif a.dtype == DType.INT16:
2688 out_dtype = DType.INT48
2689 elif a.dtype == DType.FLOAT:
2690 out_dtype = DType.FLOAT
2691 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002692 raise Exception("UNsupported input dtype for matmul: {}".format(a.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002693
Kevin Cheng550ccc52021-03-03 11:21:43 -08002694 return ser.addOutput(output_shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002695
2696 @staticmethod
2697 def concatOp(ser, a, b, axis):
2698
2699 output_shape = a.shape.copy()
2700 output_shape[axis] = a.shape[axis] + b.shape[axis]
2701
Kevin Cheng550ccc52021-03-03 11:21:43 -08002702 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002703
2704 @staticmethod
2705 def padOp(ser, a, padding):
2706
2707 output_shape = a.shape.copy()
2708
2709 for i in range(len(output_shape)):
2710 output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
2711
Kevin Cheng550ccc52021-03-03 11:21:43 -08002712 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002713
2714 @staticmethod
2715 def reshapeOp(ser, a, shape):
2716 output_shape = shape.copy()
2717
2718 totalElements = 1
2719 for i in a.shape:
2720 totalElements *= i
2721
2722 # If there are any -1 elements, figure out what that dimension must be
2723 totalOutputElements = 1
2724 for i in output_shape:
2725 if i != -1:
2726 totalOutputElements *= i
2727
2728 # And fill it in
2729 for i in range(len(output_shape)):
2730 if output_shape[i] == -1:
2731 output_shape[i] = totalElements // totalOutputElements
2732
Kevin Cheng550ccc52021-03-03 11:21:43 -08002733 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002734
2735 @staticmethod
2736 def sliceOp(ser, a, begin, size):
2737
2738 output_shape = size.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002739 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002740
2741 @staticmethod
2742 def tileOp(ser, a, multiples):
2743
2744 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002745 assert len(multiples) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07002746
2747 for i in range(len(output_shape)):
2748 output_shape[i] = a.shape[i] * multiples[i]
2749
Kevin Cheng550ccc52021-03-03 11:21:43 -08002750 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002751
2752 @staticmethod
2753 def transposeOp(ser, a, perms):
2754 output_shape = a.shape.copy()
Kevin Cheng550ccc52021-03-03 11:21:43 -08002755 assert len(perms) == len(output_shape)
Eric Kunzee5e26762020-10-13 16:11:07 -07002756
2757 for i in range(len(output_shape)):
2758 output_shape[i] = a.shape[perms[i]]
2759
Kevin Cheng550ccc52021-03-03 11:21:43 -08002760 return ser.addOutput(output_shape, a.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002761
2762 @staticmethod
Kevin Cheng77d0f762020-11-24 10:26:32 -08002763 def gatherOp(ser, values, indices):
2764 assert len(values.shape) == 3
2765 assert len(indices.shape) == 2
2766 assert values.shape[0] == indices.shape[0]
Eric Kunzee5e26762020-10-13 16:11:07 -07002767
Kevin Cheng77d0f762020-11-24 10:26:32 -08002768 output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
2769
Kevin Cheng550ccc52021-03-03 11:21:43 -08002770 return ser.addOutput(output_shape, values.dtype)
Kevin Cheng77d0f762020-11-24 10:26:32 -08002771
2772 @staticmethod
2773 def scatterOp(ser, values_in, indices, input):
2774 assert len(values_in.shape) == 3
2775 assert len(indices.shape) == 2
2776 assert len(input.shape) == 3
Kevin Cheng550ccc52021-03-03 11:21:43 -08002777 assert values_in.shape[0] == indices.shape[0] # N
2778 assert input.shape[1] == indices.shape[1] # W
2779 assert values_in.shape[2] == input.shape[2] # C
Kevin Cheng77d0f762020-11-24 10:26:32 -08002780
2781 output_shape = values_in.shape
2782
Kevin Cheng550ccc52021-03-03 11:21:43 -08002783 return ser.addOutput(output_shape, values_in.dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002784
2785 @staticmethod
Jeremy Johnsonf54d8a22021-07-20 16:01:06 +01002786 def tableOp(ser, input, table_dtype):
2787 # Same shape as the input, but dtype dependent on table dtype
2788 assert table_dtype == DType.INT16 or table_dtype == DType.INT8
2789 output_dtype = DType.INT32 if table_dtype == DType.INT16 else DType.INT8
2790 return ser.addOutput(input.shape, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002791
2792 @staticmethod
Kevin Cheng550ccc52021-03-03 11:21:43 -08002793 def resizeOp(
2794 ser,
2795 input,
2796 mode,
2797 stride,
2798 offset,
2799 shift,
2800 stride_fp,
2801 offset_fp,
2802 output_dims,
2803 input_dtype,
2804 output_dtype,
2805 ):
Eric Kunzee5e26762020-10-13 16:11:07 -07002806
2807 output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
2808
Kevin Cheng77d0f762020-11-24 10:26:32 -08002809 if input_dtype == DType.FLOAT:
2810 if stride_fp[0] <= 0 or stride_fp[1] <= 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002811 ser.setExpectedFailure(True, "Negative or zero stride")
Kevin Cheng77d0f762020-11-24 10:26:32 -08002812 else:
2813 if stride[0] <= 0 or stride[1] <= 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002814 ser.setExpectedFailure(True, "Negative or zero stride")
Eric Kunzee5e26762020-10-13 16:11:07 -07002815
Kevin Chengaee1fac2020-11-11 13:54:06 -08002816 if mode == ResizeMode.BILINEAR:
2817 if input_dtype == DType.INT8:
2818 if output_dtype != DType.INT32:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002819 ser.setExpectedFailure(True, "Invalid output data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002820 elif input_dtype == DType.INT16:
2821 if output_dtype != DType.INT48:
Kevin Cheng989cb052021-04-28 16:29:44 -07002822 ser.setExpectedFailure(true, "Invalid output data type")
Kevin Cheng77d0f762020-11-24 10:26:32 -08002823 elif input_dtype == DType.FLOAT:
2824 if output_dtype != DType.FLOAT:
Kevin Cheng989cb052021-04-28 16:29:44 -07002825 ser.setExpectedFailure(true, "Invalid output data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002826 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002827 ser.setExpectedFailure(true, "Invalid input data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002828
2829 elif mode == ResizeMode.NEAREST:
2830 if input_dtype == DType.INT8:
2831 if output_dtype != DType.INT8:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002832 ser.setExpectedFailure(True, "Invalid output data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002833 elif input_dtype == DType.INT16:
2834 if output_dtype != DType.INT16:
Kevin Cheng989cb052021-04-28 16:29:44 -07002835 ser.setExpectedFailure(true, "Invalid output data type")
Kevin Cheng77d0f762020-11-24 10:26:32 -08002836 elif input_dtype == DType.FLOAT:
2837 if output_dtype != DType.FLOAT:
Kevin Cheng989cb052021-04-28 16:29:44 -07002838 ser.setExpectedFailure(true, "Invalid output data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002839 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002840 ser.setExpectedFailure(true, "Invalid input data type")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002841
2842 else:
Kevin Cheng989cb052021-04-28 16:29:44 -07002843 ser.setExpectedFailure(true, "Invalid resize mode")
Kevin Chengaee1fac2020-11-11 13:54:06 -08002844
Kevin Cheng550ccc52021-03-03 11:21:43 -08002845 return ser.addOutput(output_dims, output_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002846
2847 @staticmethod
2848 def typeConversionOp(ser, val, out_dtype):
Kevin Cheng550ccc52021-03-03 11:21:43 -08002849 return ser.addOutput(val.shape, out_dtype)
Eric Kunzee5e26762020-10-13 16:11:07 -07002850
2851 @staticmethod
2852 def transposeConv2DOp(ser, ifm, output_shape):
Kevin Cheng3a478572021-01-22 17:21:02 -08002853 if ifm.dtype == DType.INT8:
Eric Kunzee5e26762020-10-13 16:11:07 -07002854 out_dtype = DType.INT32
2855 elif ifm.dtype == DType.INT16:
2856 out_dtype = DType.INT48
2857 elif ifm.dtype == DType.FLOAT:
2858 out_dtype = DType.FLOAT
2859 else:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002860 raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
Eric Kunzee5e26762020-10-13 16:11:07 -07002861
2862 if output_shape[1] <= 0 or output_shape[2] <= 0:
Kevin Cheng550ccc52021-03-03 11:21:43 -08002863 ser.setExpectedFailure(True, "Negative output shape")
Eric Kunzee5e26762020-10-13 16:11:07 -07002864
Kevin Cheng550ccc52021-03-03 11:21:43 -08002865 return ser.addOutput(output_shape, out_dtype)