# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
import math
import warnings

import numpy as np
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_utils import get_accum_dtype_from_tgTypes
from generator.tosa_utils import get_wrong_output_type
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode

# DTypeNames, DType, Op and ResizeMode are convenience variables to the
# flatc-generated types that should be enums, but aren't


class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(testGen, dtype, error_name=None):

        if dtype == DType.INT8:
            if testGen.args.zeropoint is not None:
                return min(127, max(-128, testGen.args.zeropoint))
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            if testGen.args.zeropoint is not None:
                return min(255, max(0, testGen.args.zeropoint))
            return testGen.randInt(0, 256)
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            zero_point = testGen.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

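    # Illustrative behaviour (a sketch of expected results, not part of the
    # generator):
    #   getZeroPoint(testGen, DType.INT8)  -> random zero point in [-128, 127]
    #   getZeroPoint(testGen, DType.UINT8) -> random zero point in [0, 255]
    #   any other dtype                    -> 0, unless a *ZeroPointNotZero
    #                                         ERROR_IF case forces it non-zero
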
    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
        #     scaleFp, scaleBits, m, multiplier, shift))

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift

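    # Worked example (a sketch, not part of the class): with scaleFp = 0.75
    # and scale32 = True, math.frexp(0.75) gives m = 0.75, shift = 0, so
    # multiplier = round(0.75 * 2**31) = 1610612736 and the returned shift is
    # (-0) + 31 = 31, satisfying scaleFp ~= multiplier / (1 << shift).
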

class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = min(values_in_shape[0], testGen.args.max_batch_size)

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If this is the chosen input, pick a random dimension to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                if error_name == ErrorIf.DimensionMismatch:
                    shape_bcast[fuzz_idx] += 1
                elif error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension()

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList

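    # Worked example (a sketch, not part of the class): for
    # shapeList = [[2, 8, 3], [2, 8, 3], [2, 8, 3]] and axis = 1, the first
    # shape is kept and its axis-1 length of 8 is split across the remaining
    # inputs, giving [[2, 8, 3], [2, 4, 3], [2, 4, 3]] so the const inputs
    # concatenate back to the original shape.
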

class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        pCount, cCount = op["operands"]

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        return tens

    @staticmethod
    def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32 and error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
627 ), "Op.NEGATE must have 1 placeholders, 0 consts"
            # Must create tensors with values within accumulator (int32) negatable
            # range
            max_val = (1 << 31) - 1
            min_val = -max_val
            arr = np.int32(
                testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Make sure the operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            placeholders = []
            add = op["op"] == Op.ADD
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much the values exceed the maximum/minimum
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

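    # Worked example (a sketch, not part of the class): for Op.ADD with
    # a = 2**30 and b = 2**30, res_arr = 2**31 exceeds max_i32 by 1, so
    # sat_max_arr = 1 and b is clipped to 2**30 - 1; the test then computes
    # a + b_unsat = 2**31 - 1, the largest sum that does not saturate int32.
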
    @staticmethod
    def tvgCondIfWhileLoop(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        if dtypeList[0] in (
            DType.INT32,
            DType.INT16,
            DType.INT8,
        ):
            # Limit input tensors with cond_if_binary or while_loop to stop
            # saturation of add/sub ops with int32 and keep all logical shift
            # values between 0 and 31 for int16 or int8
            pCount, cCount = op["operands"]
            pRemain = pCount
            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[0] == DType.INT32:
                    arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
                else:
                    arr = np.int32(
                        testGen.rng.integers(low=0, high=32, size=shapeList[idx])
                    )
                if pRemain > 0:
                    placeholders.append(
                        testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
                    )
                    pRemain -= 1
                else:
                    placeholders.append(
                        testGen.ser.addConst(shape, dtypeList[idx], arr)
                    )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgArithmeticRightShift(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        pCount, cCount = op["operands"]
        # Force value of operand[1] to be within [0, num_bits]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

        placeholders = []
        for idx, shape in enumerate(shapeList[:]):
            if idx == 1:
                if dtypeList[idx] == DType.INT8:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                elif dtypeList[idx] == DType.INT16:
                    arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
                elif dtypeList[idx] == DType.INT32:
                    arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
                elif error_name == ErrorIf.WrongInputType:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                else:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
            else:
                arr = testGen.getRandTensor(shape, dtypeList[idx])
            placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))

        return placeholders

    @staticmethod
    def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        # Set datatype of condition tensor to boolean
        dtypeList[0] = DType.BOOL

        return TosaTensorValuesGen.tvgDefault(
            testGen, op, dtypeList, shapeList, testArgs, error_name
        )

    @staticmethod
    def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
            while True:
                dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
                    continue

                break

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            tens = []
            if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32):
                tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure the multiply result is within int32 range
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                elif error_name == ErrorIf.WrongInputType:
                    num_bits = 8
                else:
                    raise Exception("OpMul: invalid input dtype")

                for idx, shape in enumerate(shapeList[:]):
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                    a_arr = np.int32(
                        testGen.rng.integers(low=low, high=high, size=shapeList[0])
                    )
                    b_arr = np.int32(
                        testGen.rng.integers(low=low, high=high, size=shapeList[1])
                    )

                i = 0
                while True:

                    a_arr_64 = a_arr.astype(np.int64)
                    b_arr_64 = b_arr.astype(np.int64)

                    if shift > 0:
                        rounding = 1 << (shift - 1)
                        result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                    else:
                        result_arr = a_arr_64 * b_arr_64

                    if (result_arr > -(2**31)).all() and (
                        result_arr <= ((2**31) - 1)
                    ).all():
                        break

                    i = i + 1
                    a_arr = a_arr // 2
                    b_arr = b_arr // 2

                placeholders.append(
                    testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)

            return tens
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        count = len(shapeList) - testGen.args.num_const_inputs_concat
        if count < 1:
            count = 1
        if testGen.args.num_const_inputs_concat == 0:
            count = len(shapeList)

        # Ensure axis is an int
        testArgs[0] = int(testArgs[0])

        shapeList = TosaTensorGen.tgConcatConstInput(
            testGen, shapeList, testArgs[0], error_name
        )

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
        )
        tens.extend(testGen.buildConstTensors(shapeList[count:], dtypeList[count:]))

        return tens

    @staticmethod
    def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        pCount, cCount = op["operands"]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
        values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
        shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
        placeholders = []
        placeholders.append(
            testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
        )
        placeholders.append(
            testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
        )

        return placeholders

    @staticmethod
    def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        testGen.rng.integers(
                            0, max(shapeList[0][axis], shapeList[1][axis])
                        )
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
991 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
            # Limit values so that the sum cannot exceed the range of an int32 during
            # summation of any axis
            range_val = int((1 << 31) / max(shapeList[0]))
            values_arr = np.int32(
                testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )


class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for
    operators that take attributes or other parameters.

    The return value is a list of (descriptive_name, [arglist]) tuples where
    the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function.
    """

    def __init__(self):
        pass

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype, error_name=None):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype, error_name=None):
        """Build the axis argument for operators that take a single axis"""
        axes = []
        shape = shapeList[0]

        if error_name == ErrorIf.AxisSmallerZero:
            small_axis = testGen.rng.integers(-5, 0)
            axes.append(("axis{}".format(small_axis), [small_axis]))
        elif error_name == ErrorIf.AxisLargerRank:
            large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
            axes.append(("axis{}".format(large_axis), [large_axis]))
        else:
            for a in range(0, len(shape)):
                axes.append(("axis{}".format(a), [a]))

        return axes

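    # Worked example (a sketch, not part of the class): for a rank-3 input
    # shape and no error case, agAxis returns
    # [("axis0", [0]), ("axis1", [1]), ("axis2", [2])], one test per axis.
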
    @staticmethod
    def _calculate_sparsity(num_tests, sparsity_factor):
        sparsity = num_tests // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
        return sparsity

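    # Worked example (a sketch, not part of the class):
    # _calculate_sparsity(10000, 120) starts at 10000 // 120 + 1 = 84 and
    # steps past 84, 85, 86, 87 and 88 (multiples of 2, 3 or 5) to return 89,
    # i.e. roughly one in every 89 parameter combinations is kept.
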
    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]
        # determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3])
        k = [int(x) for x in opName.split("_")[-1].split("x")]

        accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        # Check the rank
        rank = 5 if opName.startswith("conv3d") else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits batch and channels
        k_rank = rank - 2
        assert len(k) == k_rank

        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        if error_name == ErrorIf.PadSmallerZero:
            p_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            # Stride must be greater than 1 to force non-integer error
            startStride = 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
            s_vals = [x for x in range(startStride, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * k_rank))}
        if error_name == ErrorIf.DilationSmallerOne:
            d_vals = [testGen.rng.choice(range(-5, 1))]
        else:
            d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

        if not error_name and testGen.args.oversize:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    {x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))}
                )
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
            bigDilation = 7
            dilations.update(
                {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
            )

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 120
        sparsity = TosaArgGen._calculate_sparsity(
            len(paddings) * len(strides) * len(dilations), sparsity_factor
        )

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a positive
                        # sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k[1] - 1)
                        and (
                            k_rank < 3
                            or ((ifm_shape[3] - 1 + p[4] + p[5]) > d[2] * (k[2] - 1))
                        )
                    ):
                        remainders = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            remainders.append(
                                (
                                    ifm_shape[index + 1]
                                    - 1
                                    + p[pad_offset]
                                    + p[pad_offset + 1]
                                    - (k[index] - 1) * d[index]
                                )
                                % s[index]
                            )
                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        "".join([str(x) for x in s]),
                                        "".join([str(x) for x in p]),
                                        "".join([str(x) for x in d]),
                                    ),
                                    [accum_dtype, s, p, d],
                                )
                            )
                    n += 1

        return arg_list

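    # Illustrative output entry (a sketch; assumes testGen.typeStr renders
    # DType.INT32 as "i32"): stride (1, 1), no padding and dilation (1, 1)
    # would produce
    # ("acci32_st11_pad0000_dilat11", [DType.INT32, (1, 1), (0, 0, 0, 0), (1, 1)]).
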
    @staticmethod
    def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):

        assert isinstance(dtypes, list) or isinstance(
            dtypes, tuple
        ), f"{dtypes} unexpected"
        input_dtype = dtypes[0]

        if error_name == ErrorIf.WrongOutputType:
            accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype)
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            accum_dtype = DType.INT32
        else:
            accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])]

    @staticmethod
    def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
        # Get valid accumulate type(s)
        if dtype == DType.INT8:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.INT16:
            accum_dtypes = [DType.INT48]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FP32]
        elif dtype == DType.BF16:
            accum_dtypes = [DType.FP32]
        elif dtype == DType.FP32:
            accum_dtypes = [DType.FP32]
        elif error_name is None:
            assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"

        if error_name == ErrorIf.WrongOutputType:
            # Get incorrect output dtype for ErrorIf case
            accum_dtypes = [get_wrong_output_type(opName, testGen.rng, dtype)]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            accum_dtypes = [DType.INT32]

        return [(f"acc{testGen.typeStr(a)}", [a]) for a in accum_dtypes]

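    # Illustrative output (a sketch; assumes testGen.typeStr renders DType.FP16
    # and DType.FP32 as "f16" and "f32"): for a DType.FP16 input, agMatMul
    # would return [("accf16", [DType.FP16]), ("accf32", [DType.FP32])], one
    # test per valid accumulator type.
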
    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        # Must be rank 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == 4
            assert len(filter_shape) == 4

        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        smallest_padding_size = -min(filter_shape[1], filter_shape[2]) + 1
        if error_name == ErrorIf.PadLargerEqualKernel:
            max_filter_size = -max(filter_shape[1], filter_shape[2])
            p_vals = [testGen.rng.choice(range(max_filter_size - 10, max_filter_size))]
        else:
            p_vals = [
                x
                for x in range(smallest_padding_size, testGen.args.max_conv_padding + 1)
            ]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}

        if not error_name and testGen.args.oversize:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    {
                        x
                        for x in itertools.product(
                            *([[smallest_padding_size, bigPadding]] * 4)
                        )
                    }
                )
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 10
        sparsity = len(paddings) * len(strides) // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                if n % sparsity == 0:
                    # Determine the output shape
                    oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + filter_shape[1]
                    ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + filter_shape[2]
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]
                    arg_list.append(
                        (
                            "acc{}_st{}_pad{}_os{}".format(
                                testGen.typeStr(accum_dtype),
                                "".join([str(x) for x in s]),
                                "".join([str(x) for x in p]),
                                "x".join([str(x) for x in os]),
                            ),
                            [accum_dtype, s, p, os],
                        )
                    )
                n += 1

        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            return []

        list_shape_pad_values = list(shape_pad_values)
        # If we are producing tests for rank 6 or greater use sparsity
        if len(list_shape_pad_values) > 1024:
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(list_shape_pad_values), sparsity_factor
            )
        else:
            sparsity = 1

        for n, paddings in enumerate(list_shape_pad_values):
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                if all([p > -1 for p in paddings[i]]):
                    args_valid = False
            if args_valid and n % sparsity == 0:
                name = "pad"
                for r in range(rank):
                    before, after = paddings[r]
                    name = f"{name}{before}{after}"
                arg_list.append(
                    (name, [np.array(paddings), pad_const_int, pad_const_fp])
                )

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

        return arg_list

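    # Illustrative output entry (a sketch, not part of the class): for a
    # rank-2 input with paddings [(0, 1), (1, 0)], the generated name is
    # "pad0110" and the argument list holds the padding array plus the
    # integer and floating-point pad constants.
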
1359 @staticmethod
1360 def agPooling(testGen, opName, shapeList, dtype, error_name=None):
1361 arg_list = []
1362
1363 shape = shapeList[0]
1364 if error_name != ErrorIf.WrongRank:
1365 assert len(shape) == 4
1366
1367 # Generate comprehensive argument lists
1368 p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
1369 paddings = {x for x in itertools.product(*([p_vals] * 4))}
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001370 # Stride must be greater than 1 to force non-integer error
1371 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
1372 s_vals = [x for x in range(startStride, testGen.args.max_pooling_stride + 1)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001373 strides = {x for x in itertools.product(*([s_vals] * 2))}
1374 k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)]
1375 kernels = {x for x in itertools.product(*([k_vals] * 2))}
1376
James Ward8b390432022-08-12 20:48:56 +01001377 if opName == "max_pool2d":
1378 accum_dtypes = [None] # max_pool has no accumulate dtype
1379 elif dtype == DType.INT8 or dtype == DType.INT16:
1380 accum_dtypes = [DType.INT32]
1381 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001382 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01001383 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001384 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01001385 elif error_name is None:
1386 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
1387 else:
1388 # Set to something for the ErrorIf case which has
1389 # incorrect input data-type
1390 accum_dtypes = [DType.INT32]
1391
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001392 if testGen.args.oversize:
1393 # add some oversize argument values
1394 bigStride = 7
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001395 strides.update(
1396 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
1397 )
1398 bigKernel = 9
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001399 kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
1400 if max(shape) < 64:
1401 # padding must be less than the kernel size
1402 bigPadding = bigKernel - 1
1403 paddings.update(
1404 {x for x in itertools.product(*([[0, bigPadding]] * 4))}
1405 )
1406
1407 # There are too many parameter combinations, so generate them sparsely,
1408 # very sparse for negative tests
1409 sparsity_factor = 2 if error_name else 500
1410 sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
1411
James Ward8b390432022-08-12 20:48:56 +01001412 arg_str = (
1413 "acc{}_st{}_kern{}_pad{}"
1414 if accum_dtypes[0] is not None
1415 else "st{}_kern{}_pad{}"
1416 )
1417
1418 def get_arg_list_element(accum, stride, pad, kern):
1419 # Return tuple containing the formatted argument string and
1420 # the corresponding argument values
1421 arg_str_elems = [
1422 "".join([str(x) for x in stride]),
1423 "".join([str(x) for x in kern]),
1424 "".join([str(x) for x in pad]),
1425 ]
1426 # Note: different order to string
1427 arg_val_elems = [stride, pad, kern]
1428
1429 if accum is not None:
1430 arg_str_elems.insert(0, testGen.typeStr(accum))
1431 arg_val_elems.insert(0, accum)
1432 return (arg_str.format(*arg_str_elems), arg_val_elems)
1433
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001434 n = 0
James Ward8b390432022-08-12 20:48:56 +01001435 for a in accum_dtypes:
1436 for s in sorted(list(strides)):
1437 for p in sorted(list(paddings)):
1438 for k in sorted(list(kernels)):
1439 if error_name in [
1440 ErrorIf.StrideSmallerOne,
1441 ErrorIf.KernelSmallerOne,
1442 ErrorIf.PadSmallerZero,
1443 ErrorIf.PadLargerEqualKernel,
1444 ]:
1445 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
1446 testGen, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001447 )
James Ward8b390432022-08-12 20:48:56 +01001448 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
1449 arg_vals = [a, sNew, pNew, kNew]
1450 arg_list.append(get_arg_list_element(*arg_vals))
1451 elif (
1452 n % sparsity == 0
1453 # padding must not exceed the kernel size
1454 and p[0] < k[0]
1455 and p[1] < k[0]
1456 and p[2] < k[1]
1457 and p[3] < k[1]
1458 # the padded shape must exceed the kernel size
1459 and (shape[1] + p[0] + p[1]) > k[0]
1460 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001461 ):
James Ward8b390432022-08-12 20:48:56 +01001462 remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0]
1463 remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1]
1464 if (
1465 # the parameters must produce integer exact output
1466 error_name != ErrorIf.PoolingOutputShapeNonInteger
1467 and remainder_h == 0
1468 and remainder_w == 0
1469 ) or (
1470 error_name == ErrorIf.PoolingOutputShapeNonInteger
1471 and (remainder_h != 0 or remainder_w != 0)
1472 ):
1473 arg_vals = [a, s, p, k]
1474 arg_list.append(get_arg_list_element(*arg_vals))
1475 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001476
1477 return arg_list
1478
    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        if error_name == ErrorIf.WrongOutputType:
            dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
        elif inDtype == DType.INT8:
            dtypeList = [
                DType.BOOL,
                DType.INT16,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT16:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT32:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT16,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FP16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
        elif inDtype == DType.BF16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
        elif inDtype == DType.FP32:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output type for incorrect input type
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(testGen.typeStr(dtype)), [dtype]))

        return arg_list

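    # Example (illustrative): casting from DType.BOOL enumerates the three
    # integer targets, so the list above would contain entries such as
    #   ("outi8", [DType.INT8]), ("outi16", [DType.INT16]), ("outi32", [DType.INT32])
    # assuming testGen.typeStr() renders the types as "i8"/"i16"/"i32".
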
    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # These ErrorIfs are only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only valid output dtypes for a UINT8 input are INT8/INT16,
                # skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only valid input dtypes for a UINT8 output are INT8/INT16,
                # skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only valid output dtype for a UINT16 input is INT16,
                # skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only valid input dtype for a UINT16 output is INT16,
                # skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition. Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [outDtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

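    # Example (illustrative): a positive INT8 -> INT16 test with scale32=True,
    # double_round=True and per_channel=False would be recorded as
    #   ("outi16_sc1_dr1_pc0", [DType.INT16, True, True, False])
    # again assuming testGen.typeStr(DType.INT16) == "i16".
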
    @staticmethod
    def agMul(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        if dtype is DType.INT32:
            for p in range(testGen.args.num_rand_permutations):

                shift = testGen.randInt(0, 32)

                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("perm0_shift0", [0]))

        return arg_list

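    # Example (illustrative): only INT32 multiplies take a shift, so a run with
    # num_rand_permutations=2 might yield
    #   [("perm0_shift17", [17]), ("perm1_shift5", [5])]
    # while any other dtype always yields [("perm0_shift0", [0])].
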
    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    @staticmethod
    def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        arg_list.append(("inverseTrue", [True]))
        arg_list.append(("inverseFalse", [False]))

        return arg_list

    # Helper function for reshape. Gets some factors of a larger number.
    @staticmethod
    def getFactors(val, start=1):
        factors = []

        for i in range(start, int(np.sqrt(val)) + 1):
            if (val % i) == 0:
                factors.append(i)

        return factors

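    # Example (illustrative): the loop above only walks up to sqrt(val), so it
    # returns the "small half" of each factor pair, e.g.
    #   TosaArgGen.getFactors(24) -> [1, 2, 3, 4]
    #   TosaArgGen.getFactors(36) -> [1, 2, 3, 4, 6]
    # agReshape compensates by appending the remaining element count itself.
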
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            # Rank from 1 to TOSA_TENSOR_MAX_RANK
            newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter limits the number of attempts to find a new shape
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank - 1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list

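    # Example (illustrative): for origShape [6, 4] (24 elements) a permutation
    # with newRank=3 might produce newShape [2, 3, 4] or [4, 2, 3]; every
    # generated shape multiplies back to the original element count.
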
    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]

        if error_name == ErrorIf.IndexOutsideBounds:
            incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
            incorrect_small_index = range(-len(ifm_shape), 0)
            permutations = [p for p in itertools.permutations(incorrect_large_index)]
            permutations.extend(
                [p for p in itertools.permutations(incorrect_small_index)]
            )
        elif error_name == ErrorIf.IndexUsedTwice:
            # Create list with a duplicated index
            perm_range = list(range(len(ifm_shape)))
            index_choice = testGen.rng.choice(range(len(perm_range)))
            perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
            permutations = [p for p in itertools.permutations(perm_range)]

        else:
            # Get all permutations
            permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to the available permutations or the argument setting,
        # whichever is smaller
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Randomly shuffle the full set of permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create the list with the required number of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list

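    # Example (illustrative): a rank-3 input has 3! = 6 axis permutations, so
    # with num_rand_permutations=4 the list holds four shuffled picks such as
    #   ("perm0", [[2, 0, 1]]), ("perm1", [[0, 2, 1]]), ...
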
    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            start = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    start.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - start[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    start.append(0)
                    size.append(1)

            if valid:
                # If an ERROR_IF test is required then an incorrect start/size
                # will be returned
                start, size = TosaErrorIfArgGen.eiSliceErrorIf(
                    testGen, error_name, ifm_shape, start, size
                )
                arg_list.append(("perm{}".format(p), [start, size]))
        return arg_list

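    # Example (illustrative): for ifm_shape [1, 8, 5] one permutation might be
    #   ("perm0", [[0, 2, 1], [1, 4, 3]])
    # i.e. start [0, 2, 1] and size [1, 4, 3]; dimensions of size 1 always get
    # start 0 and size 1.
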
    @staticmethod
    def agTile(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random, but small, multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                if ifm_shape[i] > 1000:
                    # Use a multiple of 1 if the ifm_shape dimension is large,
                    # to limit the tensor size
                    multiples.append(1)
                elif max(ifm_shape) > 1000:
                    multiples.append(2)
                else:
                    multiples.append(testGen.randInt(1, 4))
            arg_list.append(("perm{}".format(p), [multiples]))

        return arg_list

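    # Example (illustrative): for ifm_shape [3, 1200, 5] the rules above give
    # multiples [2, 1, 2] deterministically: the 1200 dimension is capped at 1,
    # and the other dimensions get 2 because max(ifm_shape) > 1000.
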
    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []
        ifm_shape = shapeList[0]

        def get_aspect_ratio_resize_params():
            common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
            aspect_ratio = testGen.rng.choice(common_aspect_ratios)
            invert = testGen.rng.choice((False, True))
            letterbox = testGen.rng.choice((False, True))

            scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
            scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
            scale_y_d = scale_x_d = 1
            offset_x = offset_y = 0

            if letterbox:
                max_border = scale_y_n
                border_y = testGen.randInt(low=0, high=max_border)
                border_x = 0
            else:
                # Pillarboxing
                border_y = 0
                max_border = scale_x_n
                border_x = testGen.randInt(low=0, high=max_border)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)

            return scale, offset, border

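        # Example (illustrative): choosing (16, 9) without inversion gives
        # scale (9, 1, 16, 1); with letterboxing the border is applied on the
        # y axis only, e.g. border (4, 0), mimicking a 16:9 aspect-ratio resize.
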
        def get_upscale_downscale_params():
            valid_params = False
            while not valid_params:
                upscale = testGen.rng.choice((False, True))

                # True if sampling begins from (0, 0). Otherwise (-0.5, -0.5)
                origin_sampling = testGen.rng.choice((False, True))

                if upscale:
                    shift = testGen.randInt(low=1, high=4)
                    scale_x_d = scale_y_d = 1
                    scale_x_n = scale_y_n = (
                        1 << shift if origin_sampling else 2 << shift
                    )
                    border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                    offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
                else:
                    scale_x_n = 1
                    scale_y_n = 1

                    # Return list of valid scale_*_d values (max value 4) given
                    # input dim shape
                    def get_valid_denom(ifm_dim):
                        return [x for x in range(1, 5) if ifm_dim % x == 1]

                    # Generate list of valid downscale values and choose one randomly
                    valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                    valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                    if not valid_scale_y_ds and not valid_scale_x_ds:
                        # Bad parameters, skip
                        continue

                    if not valid_scale_y_ds:
                        scale_y_d = 1
                    else:
                        scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                    if not valid_scale_x_ds:
                        scale_x_d = 1
                    else:
                        scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                    border_x = border_y = 0
                    offset_y = testGen.randInt(0, 16 * scale_y_n)
                    offset_x = testGen.randInt(0, 16 * scale_x_n)
                valid_params = True

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

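        # Example (illustrative): upscale with shift=2 and origin sampling
        # gives scale (4, 1, 4, 1), offset (0, 0), border (0, 0), i.e. a x4
        # power-of-two upscale; with centre sampling the same shift gives
        # scale (8, 1, 8, 1), offset (-3, -3), border (3, 3).
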
        def get_rand_params():
            # Scale
            scale_y_n = testGen.randInt(low=1, high=(1 << 11))
            scale_x_n = testGen.randInt(low=1, high=(1 << 11))

            scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

            # Offsets and border within the scale
            offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
            border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.BF16:
                outputDTypeList = [DType.BF16]
            elif dtype == DType.FP32:
                outputDTypeList = [DType.FP32]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                perm = 0
                while perm < testGen.args.num_rand_permutations:
                    # Randomly choose which kind of params we are testing
                    _rnd_param_fn = testGen.rng.choice(
                        (
                            get_rand_params,
                            get_upscale_downscale_params,
                            get_aspect_ratio_resize_params,
                        )
                    )
                    scale, offset, border = _rnd_param_fn()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            perm += 1
                            continue
                    else:
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1

                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1

                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create an invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]

                    # Common for all data types
                    if error_name is not None:
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        [
                            mode,
                            scale,
                            offset,
                            border,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)
        return arg_list

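    # Example (illustrative): a NEAREST INT8 test using the x4 upscale above
    # would be named along the lines of
    #   "modeN_outi8_sc4x1x4x1_off0x0_bor0x0"
    # with values [ResizeMode.NEAREST, [4, 1, 4, 1], [0, 0], [0, 0],
    # DType.INT8, DType.INT8], assuming testGen.typeStr(DType.INT8) == "i8".
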
    @staticmethod
    def agTable(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        if dtype == DType.INT8:
            table = np.int32(
                testGen.rng.integers(low=-128, high=128, size=[256])
            ).tolist()
        else:  # INT16
            table = np.int32(
                testGen.rng.integers(low=-32768, high=32768, size=[513])
            ).tolist()
            # Make sure all slopes are within REQUIRE min/max 16-bit int
            for idx in range(len(table) - 1):
                slope = table[idx + 1] - table[idx]
                # Alter the next table entry to force the slope to be ok
                if slope > 32767:
                    table[idx + 1] -= slope - 32767
                if slope < -32768:
                    table[idx + 1] -= slope + 32768
                slope = table[idx + 1] - table[idx]
                assert slope <= 32767 and slope >= -32768
        arg_list.append(
            (
                "",
                [table],
            )
        )
        return arg_list

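    # Worked example (illustrative) of the slope clamp above: with
    # table[idx] = 30000 and table[idx + 1] = -10000 the slope is -40000,
    # so table[idx + 1] is adjusted by -(-40000 + 32768) = +7232 to -2768,
    # giving a slope of exactly -32768, the INT16 minimum.
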
    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
        # While loop: 0 iterations, 1, more than 1
        arg_list = []

        for iterations in [0, 1, 4]:
            arg_list.append(("iter{}".format(iterations), [iterations]))

        return arg_list
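
    # Example (illustrative): the control-flow generators are deliberately
    # simple; agCondIf yields ("cond0", [False]) and ("cond1", [True]), and
    # agWhileLoop yields ("iter0", [0]), ("iter1", [1]) and ("iter4", [4]).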