blob: 8e00fab47f0ad39a5e3b93370d2e706f6e406749 [file] [log] [blame]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001# Copyright (c) 2021-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
3import itertools
4import math
5
6import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01007from generator.tosa_error_if import ErrorIf
8from generator.tosa_error_if import TosaErrorIfArgGen
9from serializer.tosa_serializer import DTypeNames
10from tosa.DType import DType
11from tosa.Op import Op
12from tosa.ResizeMode import ResizeMode
13
14# DTypeNames, DType, Op and ResizeMode are convenience variables to the
15# flatc-generated types that should be enums, but aren't
16
17
class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(testGen, dtype, error_name=None):
        """Pick a zero point suitable for dtype.

        INT8/UINT8 get a random in-range zero point. For zero-point ERROR_IF
        tests on other dtypes a guaranteed non-zero value is produced;
        otherwise 0 is returned (required for all remaining dtypes).
        """
        if dtype == DType.INT8:
            # INT8 legitimately allows a full-range zero point
            return testGen.randInt(-128, 128)
        if dtype == DType.UINT8:
            return testGen.randInt(0, 256)
        if error_name in (
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ):
            # Negative test: must not be zero, so bump 0 up to 1
            zp = testGen.randInt(-128, 128)
            return zp if zp != 0 else 1
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Return [input_zp, output_zp] for a unary operator.

        The relevant zero point is forced non-zero when generating
        InputZeroPointNotZero / OutputZeroPointNotZero error tests.
        """
        input_err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        output_err = error_name if error_name == ErrorIf.OutputZeroPointNotZero else None
        return [
            TosaQuantGen.getZeroPoint(testGen, dtype, input_err),
            TosaQuantGen.getZeroPoint(testGen, dtype, output_err),
        ]

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Return [input_zp, weight_zp] for a convolution-style operator.

        Accepts either a single dtype (shared by all operands) or a
        [input, weights, accumulator] dtype list.
        """
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        input_err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        weight_err = error_name if error_name == ErrorIf.WeightZeroPointNotZero else None
        return [
            TosaQuantGen.getZeroPoint(testGen, dtypeList[0], input_err),
            TosaQuantGen.getZeroPoint(testGen, dtypeList[1], weight_err),
        ]

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Return [a_zp, b_zp] for MATMUL.

        Both zero points are forced non-zero for the InputZeroPointNotZero
        error test.
        """
        err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        return [
            TosaQuantGen.getZeroPoint(testGen, dtype, err),
            TosaQuantGen.getZeroPoint(testGen, dtype, err),
        ]

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale to a fixed-point multiplier and shift.

        Derived from computeMultiplierAndShiftTosaScale32. scale32 selects a
        31-bit (True) or 15-bit (False) multiplier precision. Returns
        (multiplier, shift) with shift constrained to [2, 62].
        """
        bits = 31 if scale32 else 15

        mantissa, exponent = math.frexp(scaleFp)
        # frexp's mantissa carries the sign; fold it back so the
        # multiplier is computed from a positive mantissa
        if scaleFp < 0.0:
            mantissa = -mantissa

        multiplier = round(mantissa * (1 << bits))
        assert multiplier <= (1 << bits)

        # Rounding may push the mantissa to exactly 1.0; renormalize
        if multiplier == (1 << bits):
            multiplier //= 2
            exponent += 1

        shift = bits - exponent

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier //= 4
            shift += 2
        elif shift == 1:
            multiplier //= 2
            shift += 1
        elif shift == 63:
            multiplier *= 2
            shift -= 1

        assert multiplier <= (1 << bits)
        assert 2 <= shift <= 62

        return multiplier, shift
146
147
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Generate one random shape shared by every operand.

        For RankMismatch error tests the shape rank is perturbed so operands
        after the first disagree.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Regenerate the shape with a different rank so a later operand
            # mismatches; the regeneration is skipped at i == 1
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    # rank 1 can only grow (rank 0 is not produced here)
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Generate one random rank-4 (NHWC) shape shared by every operand."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        # (MaxDimExceeded tests need the over-large dimensions kept)
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Generate [values_in, input] shapes for gather/scatter-style ops.

        Both shapes share batch (dim 0) and channel (dim 2); the second
        shape gets an independently random W (dim 1).
        """
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Generate matching shapes, then make one operand broadcastable.

        One randomly chosen operand gets a random dimension set to 1 (or, for
        error tests, deliberately mismatched in size or rank).
        """
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                if error_name == ErrorIf.DimensionMismatch:
                    # An incompatible (non-1, non-equal) dimension
                    shape_bcast[fuzz_idx] += 1
                elif error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                else:
                    # Valid broadcast: dimension of size 1
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Generate [ifm (NHWC), filter (OHWI), bias (OC)] shapes for CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Generate [ifm (NDHWC), filter (ODHWI), bias (OC)] shapes for CONV3D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Generate [ifm (NHWC), filter (OHWI), bias (OC)] shapes for TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Generate [ifm (NHWC), filter (HWCM), bias (C*M)] shapes for DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is H, W, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Generate [input, filter (OC x IC), bias (OC)] shapes for FULLY_CONNECTED."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Generate [a (N,H,C), b (N,C,W)] shapes for MATMUL."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Generate a random number of equal shapes for CONCAT.

        For ConcatInputRankMismatch tests, every shape after the first has a
        dimension randomly removed or appended.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    # Drop the first dimension to reduce the rank
                    wrongShape = wrongShape[1:]
                else:
                    # Append an extra dimension to increase the rank
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Split the first concat shape along axis across the remaining inputs.

        Keeps the overall concatenated extent on axis unchanged while avoiding
        many large identical tensors. Axis-related error tests pass through
        untouched; ConcatInputDimMismatch tests get an off-axis dimension
        perturbed instead.
        """
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                # Last iteration: the leftover length becomes the final input
                new_shapeList.append(shape.copy())

        return new_shapeList
535
536
class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """Create fully random placeholder tensors followed by const tensors."""
        pCount, cCount = op["operands"]

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        return tens

    @staticmethod
    def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """Values for NEGATE: avoid -2**31, which has no int32 negation."""
        if dtypeList[0] == DType.INT32 and error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.NEGATE must have 1 placeholders, 0 consts"
            # Must create tensors with values within accumulator (int32) negatable
            # range
            max_val = (1 << 31) - 1
            min_val = -max_val
            arr = np.int32(
                testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """Values for ADD/SUB: clip operand b so int32 results cannot saturate."""
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Make sure the operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            placeholders = []
            add = op["op"] == Op.ADD
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            # Compute the reference result in int64 so overflow is visible
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgCondIfWhileLoop(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        """Values for COND_IF / WHILE_LOOP bodies: keep inner ops in safe ranges."""
        if dtypeList[0] in (
            DType.INT32,
            DType.INT16,
            DType.INT8,
        ):
            # Limit input tensors with cond_if_binary or while_loop to stop
            # saturation of add/sub ops with int32 and keep all logical shift
            # values between 0 to 31 for int16 or int8
            pCount, cCount = op["operands"]
            pRemain = pCount
            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[0] == DType.INT32:
                    # INT16-range values cannot saturate when added in int32
                    arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
                else:
                    arr = np.int32(
                        testGen.rng.integers(low=0, high=32, size=shapeList[idx])
                    )
                # First pCount operands are placeholders, the rest consts
                if pRemain > 0:
                    placeholders.append(
                        testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
                    )
                    pRemain -= 1
                else:
                    placeholders.append(
                        testGen.ser.addConst(shape, dtypeList[idx], arr)
                    )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgArithmeticRightShift(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        """Values for ARITHMETIC_RIGHT_SHIFT: shift amounts bounded by dtype width."""
        pCount, cCount = op["operands"]
        # Force value of operand[1] to be within [0, num_bits]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

        placeholders = []
        for idx, shape in enumerate(shapeList[:]):
            if idx == 1:
                # Operand 1 is the shift amount; cap it at the bit width
                if dtypeList[idx] == DType.INT8:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                elif dtypeList[idx] == DType.INT16:
                    arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
                elif dtypeList[idx] == DType.INT32:
                    arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
                elif error_name == ErrorIf.WrongInputType:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                else:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
            else:
                arr = testGen.getRandTensor(shape, dtypeList[idx])
            placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))

        return placeholders

    @staticmethod
    def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """Values for SELECT: the condition tensor must be boolean.

        NOTE: mutates dtypeList[0] in place before delegating to tvgDefault.
        """
        # Set datatype of condition tensor to boolean
        dtypeList[0] = DType.BOOL

        return TosaTensorValuesGen.tvgDefault(
            testGen, op, dtypeList, shapeList, testArgs, error_name
        )

    @staticmethod
    def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """Values for INTDIV: regenerate until no division is undefined."""
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
            while True:
                dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
                    continue

                break

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """Values for MUL: halve operands until the shifted product fits in int32."""
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            tens = []
            if dtypeList[0] == DType.FLOAT:
                tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure multiply result in int32 range
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                elif error_name == ErrorIf.WrongInputType:
                    num_bits = 8
                else:
                    raise Exception("OpMul: invalid input dtype")

                # NOTE(review): this loop regenerates a_arr/b_arr from
                # shapeList[0]/[1] on every iteration, so only the final
                # iteration's values are used. Looks redundant, but removing
                # it would change the RNG stream and hence all generated
                # tests - confirm before changing.
                for idx, shape in enumerate(shapeList[:]):
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                    a_arr = np.int32(
                        testGen.rng.integers(low=low, high=high, size=shapeList[0])
                    )
                    b_arr = np.int32(
                        testGen.rng.integers(low=low, high=high, size=shapeList[1])
                    )

                    i = 0
                    while True:

                        a_arr_64 = a_arr.astype(np.int64)
                        b_arr_64 = b_arr.astype(np.int64)

                        # Apply the rounded right-shift the MUL op performs
                        if shift > 0:
                            rounding = 1 << (shift - 1)
                            result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                        else:
                            result_arr = a_arr_64 * b_arr_64

                        if (result_arr > -(2**31)).all() and (
                            result_arr <= ((2**31) - 1)
                        ).all():
                            break

                        # Still overflows: halve both operands and retry
                        i = i + 1
                        a_arr = a_arr // 2
                        b_arr = b_arr // 2

                placeholders.append(
                    testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)

            return tens
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """Values for CONCAT: split shapes across placeholder and const inputs."""
        count = len(shapeList) - testGen.args.num_const_inputs_concat
        if count < 1:
            count = 1
        if testGen.args.num_const_inputs_concat == 0:
            count = len(shapeList)

        # Ensure axis is an int
        testArgs[0] = int(testArgs[0])

        shapeList = TosaTensorGen.tgConcatConstInput(
            testGen, shapeList, testArgs[0], error_name
        )

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
        )
        tens.extend(testGen.buildConstTensors(shapeList[count:], dtypeList[count:]))

        return tens

    @staticmethod
    def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """Values for logical shifts: shift amounts restricted to [0, 31]."""
        pCount, cCount = op["operands"]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
        values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
        shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
        placeholders = []
        placeholders.append(
            testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
        )
        placeholders.append(
            testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
        )

        return placeholders

    @staticmethod
    def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """Values for EQUAL: force some matching elements between the operands."""
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        testGen.rng.integers(
                            0, max(shapeList[0][axis], shapeList[1][axis])
                        )
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """Values for REDUCE_SUM: bound magnitudes so no axis sum overflows int32."""
        if dtypeList[0] == DType.INT32:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
            # Limit values so that the sum cannot exceed the range of an int32 during
            # summation of any axis
            range_val = int((1 << 31) / max(shapeList[0]))
            values_arr = np.int32(
                testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
941
942
class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for
    operators that take attributes or other parameters.

    The return value is a list of (descriptive_name, [arglist]) tuples where
    the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function.
    """

    def __init__(self):
        # Stateless: the generators below are used via the class, not instances.
        pass
954
955 @staticmethod
956 def agNone(testGen, opName, shapeList, dtype, error_name=None):
957 """A trivial argument generator for operators that don't take any
958 non-tensor arguments"""
959 return [("", [])]
960
961 @staticmethod
962 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
963 """Build the axis argument for operators that take a single axis"""
964 axes = []
965 shape = shapeList[0]
966
967 if error_name == ErrorIf.AxisSmallerZero:
968 small_axis = testGen.rng.integers(-5, 0)
969 axes.append(("axis{}".format(small_axis), [small_axis]))
970 elif error_name == ErrorIf.AxisLargerRank:
971 large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
972 axes.append(("axis{}".format(large_axis), [large_axis]))
973 else:
974 for a in range(0, len(shape)):
975 axes.append(("axis{}".format(a), [a]))
976
977 return axes
978
979 @staticmethod
980 def agConv(testGen, opName, shapeList, dtype, error_name=None):
981 arg_list = []
982
983 ifm_shape = shapeList[0]
984 filter_shape = shapeList[1]
985 # determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3])
986 k = [int(x) for x in opName.split("_")[-1].split("x")]
987
988 # Check the rank
989 rank = 5 if opName.startswith("conv3d") else 4
990 if error_name != ErrorIf.WrongRank:
991 assert len(ifm_shape) == rank
992 assert len(filter_shape) == rank
993
994 # kernel rank omits batch and channels
995 k_rank = rank - 2
996 assert len(k) == k_rank
997
998 # Generate comprehensive argument lists
999 # - except for named errors, which use specific invalid value(s)
1000 if error_name == ErrorIf.PadSmallerZero:
1001 p_vals = [testGen.rng.choice(range(-5, 0))]
1002 else:
1003 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
1004 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
1005 if error_name == ErrorIf.StrideSmallerOne:
1006 # Can't use stride=0, as it is used to derive output shape, as a divisor
1007 s_vals = [testGen.rng.choice(range(-5, 0))]
1008 else:
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001009 # Stride must be greater than 1 to force non-integer error
1010 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
1011 s_vals = [x for x in range(startStride, testGen.args.max_conv_stride + 1)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001012 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
1013 if error_name == ErrorIf.DilationSmallerOne:
1014 d_vals = [testGen.rng.choice(range(-5, 1))]
1015 else:
1016 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
1017 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
1018
1019 if not error_name and testGen.args.oversize:
1020 # add some oversize argument values
1021 if max(ifm_shape) < 64:
1022 bigPadding = 9
1023 paddings.update(
1024 {x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))}
1025 )
1026 bigStride = 8
1027 strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
1028 bigDilation = 7
1029 dilations.update(
1030 {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
1031 )
1032
1033 # There are too many parameter combinations, so generate them sparsely,
1034 # very sparse for negative tests
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001035 sparsity_factor = 2 if error_name else 120
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001036 sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
1037 # If there are only a small number of tests, just select them all
1038 if sparsity < 13:
1039 sparsity = 1
1040 # To get a variety of parameter combinations sparsity should not be a
1041 # multiple of 2, 3 or 5
1042 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1043 sparsity += 1
1044
1045 n = 0
1046 for s in sorted(list(strides)):
1047 for p in sorted(list(paddings)):
1048 for d in sorted(list(dilations)):
1049 if (
1050 n % sparsity == 0
1051 # padding must not exceed the kernel size ?
1052 # and p[0] < k[0] and p[1] < k[0]
1053 # and p[2] < k[1] and p[3] < k[1]
1054 # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
1055 # the padded shape must exceed the kernel size
1056 and (ifm_shape[1] + p[0] + p[1]) > k[0]
1057 and (ifm_shape[2] + p[2] + p[3]) > k[1]
1058 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
1059 # the padded shape must exceed the dilation
1060 and (ifm_shape[1] + p[0] + p[1]) > d[0]
1061 and (ifm_shape[2] + p[2] + p[3]) > d[1]
1062 and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
1063 ):
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001064 remainders = []
1065 for index in range(k_rank):
1066 pad_offset = index * 2
1067 remainders.append(
1068 (
1069 ifm_shape[index + 1]
1070 - 1
1071 + p[pad_offset]
1072 + p[pad_offset + 1]
1073 - (k[index] - 1) * d[index]
1074 )
1075 % s[index]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001076 )
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001077 if (
1078 # the parameters must produce integer exact output
1079 error_name != ErrorIf.ConvOutputShapeNonInteger
1080 and max(remainders) == 0
1081 ) or (
1082 error_name == ErrorIf.ConvOutputShapeNonInteger
1083 and max(remainders) > 0
1084 ):
1085 arg_list.append(
1086 (
1087 "st{}_pad{}_dilat{}".format(
1088 "".join([str(x) for x in s]),
1089 "".join([str(x) for x in p]),
1090 "".join([str(x) for x in d]),
1091 ),
1092 [s, p, d],
1093 )
1094 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001095 n += 1
1096
1097 return arg_list
1098
1099 @staticmethod
1100 def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
1101 arg_list = []
1102
1103 ifm_shape = shapeList[0]
1104 filter_shape = shapeList[1]
1105
1106 # Must be rank 4
1107 if error_name != ErrorIf.WrongRank:
1108 assert len(ifm_shape) == 4
1109 assert len(filter_shape) == 4
1110
1111 # Generate comprehensive argument lists
1112 # - except for named errors, which use specific invalid value(s)
1113 if error_name == ErrorIf.PadSmallerZero:
1114 p_vals = [testGen.rng.choice(range(-5, 0))]
1115 else:
1116 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001117 paddings = {x for x in itertools.product(*([p_vals] * 4))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001118 if error_name == ErrorIf.StrideSmallerOne:
1119 # Can't use stride=0, as it is used to derive output shape, as a divisor
1120 s_vals = [testGen.rng.choice(range(-5, 0))]
1121 else:
1122 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
1123 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001124
Jeremy Johnson5860df62022-05-04 15:30:58 +01001125 if not error_name and testGen.args.oversize:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001126 # add some oversize argument values
1127 if max(ifm_shape) < 64:
1128 bigPadding = 9
1129 paddings.update(
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001130 {x for x in itertools.product(*([[0, bigPadding]] * 4))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001131 )
1132 bigStride = 8
1133 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001134
1135 # There are too many parameter combinations, so generate them sparsely,
1136 # very sparse for negative tests
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001137 sparsity_factor = 2 if error_name else 10
TatWai Chong24594f52022-06-08 00:48:04 -07001138 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001139 # If there are only a small number of tests, just select them all
1140 if sparsity < 13:
1141 sparsity = 1
1142 # To get a variety of parameter combinations sparsity should not be a
1143 # multiple of 2, 3 or 5
1144 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1145 sparsity += 1
1146
1147 n = 0
1148 for s in sorted(list(strides)):
1149 for p in sorted(list(paddings)):
TatWai Chong24594f52022-06-08 00:48:04 -07001150 if n % sparsity == 0:
1151 # Determine the output shape
1152 oh = (ifm_shape[1] - 1) * s[0] - p[0] - p[1] + filter_shape[1]
1153 ow = (ifm_shape[2] - 1) * s[1] - p[2] - p[3] + filter_shape[2]
1154 os = [ifm_shape[0], oh, ow, filter_shape[0]]
1155 arg_list.append(
1156 (
1157 "st{}_pad{}_os{}".format(
1158 "".join([str(x) for x in s]),
1159 "".join([str(x) for x in p]),
1160 "x".join([str(x) for x in os]),
1161 ),
1162 [s, p, os],
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001163 )
TatWai Chong24594f52022-06-08 00:48:04 -07001164 )
1165 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001166
1167 return arg_list
1168
1169 @staticmethod
1170 def agPad(testGen, opName, shapeList, dtype, error_name=None):
1171 arg_list = []
1172 rank = len(shapeList[0])
1173
1174 # Exhaustively test combinations of padding on each side of each dimension
1175 # - the range of padding values is defined by pad_min and pad_max
1176 # - for padding >9, the name format needs to be more distinctive
1177 pad_min, pad_max = 0, 1
1178 pad_values = [x for x in range(pad_min, pad_max + 1)]
1179 if error_name == ErrorIf.PadSmallerZero:
1180 pad_values = [x for x in range(-2, 0)]
1181 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
1182 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
1183
1184 if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
1185 pad_const_int = testGen.getRandNumberDType(dtype)
1186 pad_const_fp = 0
1187 elif dtype == DType.FLOAT:
1188 pad_const_int = 0
1189 pad_const_fp = testGen.getRandNumberDType(dtype)
1190 else:
1191 return []
1192
1193 for paddings in shape_pad_values:
1194 name = "pad"
1195 for r in range(rank):
1196 before, after = paddings[r]
1197 name = f"{name}{before}{after}"
1198 arg_list.append((name, [np.array(paddings), pad_const_int, pad_const_fp]))
1199
1200 return arg_list
1201
    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        """Build (stride, pad, kernel) argument combinations for the pooling
        operators.

        Combinations are generated exhaustively from the configured maximums
        (optionally augmented with oversize values) and then sampled
        sparsely.  For the listed ERROR_IF tests, invalid values are
        substituted via TosaErrorIfArgGen.eiPoolingErrorIf; for
        PoolingOutputShapeNonInteger, combinations are selected so the
        output size division is not integer exact.
        """
        arg_list = []

        shape = shapeList[0]
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        # Stride must be greater than 1 to force non-integer error
        startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
        s_vals = [x for x in range(startStride, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        if testGen.args.oversize:
            # add some oversize argument values
            bigStride = 7
            strides.update(
                {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
            )
            bigKernel = 9
            kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
            if max(shape) < 64:
                # padding must be less than the kernel size
                bigPadding = bigKernel - 1
                paddings.update(
                    {x for x in itertools.product(*([[0, bigPadding]] * 4))}
                )

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 500
        sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for k in sorted(list(kernels)):
                    if error_name in [
                        ErrorIf.StrideSmallerOne,
                        ErrorIf.KernelSmallerOne,
                        ErrorIf.PadSmallerZero,
                        ErrorIf.PadLargerEqualKernel,
                    ]:
                        # Substitute invalid values; combinations where the
                        # helper returns a None entry are skipped
                        sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
                            testGen, error_name, s, p, k
                        )
                        if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                            arg_list.append(
                                (
                                    "st{}_kern{}_pad{}".format(
                                        "".join([str(x) for x in sNew]),
                                        "".join([str(x) for x in kNew]),
                                        "".join([str(x) for x in pNew]),
                                    ),
                                    [sNew, pNew, kNew],
                                )
                            )
                    elif (
                        n % sparsity == 0
                        # padding must not exceed the kernel size
                        and p[0] < k[0]
                        and p[1] < k[0]
                        and p[2] < k[1]
                        and p[3] < k[1]
                        # the padded shape must exceed the kernel size
                        and (shape[1] + p[0] + p[1]) > k[0]
                        and (shape[2] + p[2] + p[3]) > k[1]
                    ):
                        # Zero remainders mean the output size is integer exact
                        remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0]
                        remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1]
                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.PoolingOutputShapeNonInteger
                            and remainder_h == 0
                            and remainder_w == 0
                        ) or (
                            error_name == ErrorIf.PoolingOutputShapeNonInteger
                            and (remainder_h != 0 or remainder_w != 0)
                        ):
                            arg_list.append(
                                (
                                    "st{}_kern{}_pad{}".format(
                                        "".join([str(x) for x in s]),
                                        "".join([str(x) for x in k]),
                                        "".join([str(x) for x in p]),
                                    ),
                                    [s, p, k],
                                )
                            )
                    n += 1

        return arg_list
1299
1300 @staticmethod
1301 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
1302 arg_list = []
1303
1304 # Enumerate the output types here
1305 if error_name == ErrorIf.WrongOutputType:
1306 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
1307 elif inDtype == DType.INT8:
1308 dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
1309 elif inDtype == DType.INT16:
1310 dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
1311 elif inDtype == DType.INT32:
1312 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
1313 elif inDtype == DType.BOOL:
1314 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
1315 elif inDtype == DType.FLOAT:
1316 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
1317 elif error_name == ErrorIf.WrongInputType:
1318 # Pick some potentially correct output type for incorrect input type
1319 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
1320 else:
1321 raise Exception("Unexpected input dtype: {}".format(inDtype))
1322
1323 for dtype in dtypeList:
1324 arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))
1325
1326 return arg_list
1327
    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        """Build (output_dtype, scale32, double_round, per_channel) argument
        combinations for RESCALE.

        Iterates every candidate output dtype and flag combination, skipping
        combinations that are illegal for the given input dtype - except
        where the requested ERROR_IF test needs an otherwise-illegal
        combination to be emitted.
        """
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                # These output types may use a non-zero zero point, so the
                # OutputZeroPointNotZero error cannot be produced with them
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition. Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    DTypeNames[outDtype],
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [outDtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list
1426
1427 @staticmethod
1428 def agMul(testGen, opName, shapeList, dtype, error_name=None):
1429 arg_list = []
1430
1431 if dtype is DType.INT32:
1432 for p in range(testGen.args.num_rand_permutations):
1433
1434 shift = testGen.randInt(0, 32)
1435
1436 arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
1437 else:
1438 arg_list.append(("perm0_shift0", [0]))
1439
1440 return arg_list
1441
1442 @staticmethod
1443 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
1444 arg_list = []
1445
1446 arg_list.append(("roundTrue", [True]))
1447 arg_list.append(("roundFalse", [False]))
1448
1449 return arg_list
1450
1451 # Helper function for reshape. Gets some factors of a larger number.
1452 @staticmethod
1453 def getFactors(val, start=1):
1454 factors = []
1455
1456 for i in range(start, int(np.sqrt(val)) + 1):
1457 if (val % i) == 0:
1458 factors.append(i)
1459
1460 return factors
1461
1462 @staticmethod
1463 def agReshape(testGen, opName, shapeList, dtype, error_name=None):
1464 arg_list = []
1465
1466 origShape = shapeList[0]
1467
1468 totalElements = 1
1469 for s in origShape:
1470 totalElements *= s
1471
1472 # This code is NOT fast. Fortunately, the numbers are fairly small.
1473 factors = TosaArgGen.getFactors(totalElements)
1474
1475 for p in range(testGen.args.num_rand_permutations):
1476 newRank = testGen.randInt(1, 7)
1477 if len(factors) < newRank:
1478 continue
1479
1480 found = True
1481 # escape_counter breaks while loop if it continues on for too long
1482 escape_counter = 0
1483 while found:
1484 newShape = []
1485 # Generate newShape ensuring it isn't a duplicate
1486 remainingElements = totalElements
1487 shuffledFactors = testGen.rng.permutation(factors)
1488 for i in range(1, newRank):
1489 # pick rank-1 factors
1490 newShape.append(shuffledFactors[0])
1491 remainingElements = remainingElements // shuffledFactors[0]
1492 shuffledFactors = testGen.rng.permutation(
1493 TosaArgGen.getFactors(remainingElements)
1494 )
1495 newShape.append(remainingElements)
1496
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001497 # Check for duplicates
1498 found = False
1499 for name, other_shape in arg_list:
1500 if other_shape[0] == newShape:
1501 found = True
1502 break
1503
1504 escape_counter += 1
1505 if escape_counter >= 100:
1506 break
1507
1508 if not found:
1509 arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))
1510
1511 return arg_list
1512
1513 @staticmethod
1514 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
1515 arg_list = []
1516
1517 ifm_shape = shapeList[0]
1518
1519 if error_name == ErrorIf.IndexOutsideBounds:
1520 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
1521 incorrect_small_index = range(-len(ifm_shape), 0)
1522 permutations = [p for p in itertools.permutations(incorrect_large_index)]
1523 permutations.extend(
1524 [p for p in itertools.permutations(incorrect_small_index)]
1525 )
1526 elif error_name == ErrorIf.IndexUsedTwice:
1527 # Create list with a duplicated index
1528 perm_range = list(range(len(ifm_shape)))
1529 index_choice = testGen.rng.choice(range(len(perm_range)))
1530 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
1531 permutations = [p for p in itertools.permutations(perm_range)]
1532
1533 else:
1534 # Get all permutations
1535 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
1536
1537 # Limit to possible permutations from shape dimension or argument setting
1538 limit = min(len(permutations), testGen.args.num_rand_permutations)
1539
1540 # Get random permutation generator that uses all permutations
1541 random_permutations = testGen.rng.permutation(permutations)
1542
1543 # Create list of required amount of permutations
1544 arg_list = [
1545 ("perm{}".format(p), [random_permutations[p].tolist()])
1546 for p in range(limit)
1547 ]
1548 return arg_list
1549
1550 @staticmethod
1551 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
1552 arg_list = []
1553
1554 ifm_shape = shapeList[0]
1555 rank = len(ifm_shape)
1556
1557 for p in range(testGen.args.num_rand_permutations):
1558 start = []
1559 size = []
1560
1561 valid = True
1562
1563 for i in range(rank):
1564 if ifm_shape[i] > 1:
1565 start.append(testGen.randInt(0, ifm_shape[i]))
1566 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
1567
1568 # Invalid slice size?
1569 if size[i] == 0:
1570 valid = False
1571 else:
1572 start.append(0)
1573 size.append(1)
1574
1575 if valid:
1576 # If ERROR_IF test required then incorrect start, size will be returned
1577 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
1578 testGen, error_name, ifm_shape, start, size
1579 )
1580 arg_list.append(("perm{}".format(p), [start, size]))
1581 return arg_list
1582
1583 @staticmethod
1584 def agTile(testGen, opName, shapeList, dtype, error_name=None):
1585 arg_list = []
1586
1587 ifm_shape = shapeList[0]
1588 rank = len(ifm_shape)
1589
1590 for p in range(testGen.args.num_rand_permutations):
1591
1592 # Pick a few random, but small multiple values
1593 # because otherwise this has a tendency to generate
1594 # enormous tensors
1595 multiples = []
1596 for i in range(rank):
1597 if ifm_shape[i] > 1000:
1598 # Multiple of 1 if ifm_shape dimension is large to reduce
1599 # tensor size
1600 multiples.append(1)
1601 elif max(ifm_shape) > 1000:
1602 multiples.append(2)
1603 else:
1604 multiples.append(testGen.randInt(1, 4))
1605 arg_list.append(("perm{}".format(p), [multiples]))
1606
1607 return arg_list
1608
    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Build RESIZE arguments: mode, stride, offset, shift, output dims
        and output dtype.

        For integer types a fixed-point shift value (1-11) is searched for
        such that the derived stride/offset stay in range and no
        out-of-bounds input access would occur; for FLOAT the
        floating-point stride/offset are used directly with shift 0.
        ERROR_IF tests pass the chosen values through
        TosaErrorIfArgGen.eiResizeErrorIf to make them invalid.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:

            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FLOAT:
                outputDTypeList = [DType.FLOAT]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            for outputDType in outputDTypeList:
                for perm in range(testGen.args.num_rand_permutations):
                    # Randomly generate legal output dimensions and shift
                    # and then compute the stride and offset based on them
                    # A output_dim of 1 will cause offset to exceed allowed range
                    # so minimum value 2 produced below
                    output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
                    # Keep the downscale factor below 16 in each axis
                    while (float(ifm_shape[1]) / float(output_dims[0])) >= 16:
                        output_dims[0] += 1
                    while (float(ifm_shape[2]) / float(output_dims[1])) >= 16:
                        output_dims[1] += 1

                    in_center_h = (ifm_shape[1] - 1) / 2.0
                    in_center_w = (ifm_shape[2] - 1) / 2.0
                    out_center_h = (output_dims[0] - 1) / 2.0
                    out_center_w = (output_dims[1] - 1) / 2.0

                    # Stride/offset that map the output centre onto the
                    # input centre
                    fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
                    fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
                    fp_offset_y = in_center_h - fp_stride_y * out_center_h
                    fp_offset_x = in_center_w - fp_stride_x * out_center_w

                    if outputDType == DType.FLOAT:
                        float_op = True
                        arg_str = (
                            "mode{}_shift{}_odim{}x{}_out{}"
                            "_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}"
                        )
                        shift = 0
                        stride = [0, 0]
                        offset = [0, 0]
                        stride_fp = [fp_stride_y, fp_stride_x]
                        offset_fp = [fp_offset_y, fp_offset_x]

                    else:
                        float_op = False
                        arg_str = "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}"
                        shift = testGen.randInt(1, 12)
                        # Now search for a shift value (1 to 11) that will produce
                        # a valid and predictable resize operation
                        count = 0
                        while count < 12:
                            unit = float(1 << shift)
                            stride_y = int(round(fp_stride_y * unit))
                            stride_x = int(round(fp_stride_x * unit))
                            offset_y = int(round(fp_offset_y * unit))
                            offset_x = int(round(fp_offset_x * unit))

                            if (
                                stride_y <= 0
                                or stride_x <= 0
                                or stride_y >= (16 << shift)
                                or stride_x >= (16 << shift)
                                or offset_y >= (16 << shift)
                                or offset_x >= (16 << shift)
                                or offset_y <= (-16 << shift)
                                or offset_x <= (-16 << shift)
                            ):
                                # Change the shift value and check again
                                count += 1
                                shift = (shift % 11) + 1
                                continue

                            def RESIZE_REQUIRE_CALC(
                                length_in, length_out, stride, offset, shift
                            ):
                                # Perform the pseudo loop to look for out of bounds
                                for pos in range(0, length_out):
                                    a = pos * stride + offset
                                    ia = a >> shift
                                    ia0 = max(ia, 0)
                                    ia1 = min(ia + 1, length_in - 1)
                                    if ia0 > ia1:
                                        # Found a problem value
                                        break
                                return ia0, ia1

                            iy0, iy1 = RESIZE_REQUIRE_CALC(
                                ifm_shape[1], output_dims[0], stride_y, offset_y, shift
                            )
                            ix0, ix1 = RESIZE_REQUIRE_CALC(
                                ifm_shape[2], output_dims[1], stride_x, offset_x, shift
                            )
                            if ix0 > ix1 or iy0 > iy1:
                                # Change the shift value and check again
                                count += 1
                                shift = (shift % 11) + 1
                                continue
                            break

                        if count >= 12:
                            # Couldn't find a good set of values for this test, skip it
                            continue

                        stride = [stride_y, stride_x]
                        offset = [offset_y, offset_x]

                        stride_fp = [0.0, 0.0]
                        offset_fp = [0.0, 0.0]

                    # Common for all data types
                    if error_name is not None:
                        (
                            shift,
                            stride,
                            stride_fp,
                            offset,
                            offset_fp,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            shift,
                            stride,
                            stride_fp,
                            offset,
                            offset_fp,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_list.append(
                        (
                            arg_str.format(
                                "N" if mode == ResizeMode.NEAREST else "B",
                                shift,
                                output_dims[0],
                                output_dims[1],
                                testGen.typeStr(outputDTypeNew),
                                stride_fp[0] if float_op else stride[0],
                                stride_fp[1] if float_op else stride[1],
                                offset_fp[0] if float_op else offset[0],
                                offset_fp[1] if float_op else offset[1],
                            ),
                            [
                                mode,
                                stride,
                                offset,
                                shift,
                                stride_fp,
                                offset_fp,
                                output_dims,
                                dtype,
                                outputDTypeNew,
                            ],
                        )
                    )

        return arg_list
1787
1788 @staticmethod
1789 def agTable(testGen, opName, shapeList, dtype, error_name=None):
1790 arg_list = []
1791
1792 if dtype == DType.INT8:
1793 table = np.int32(
1794 testGen.rng.integers(low=-128, high=128, size=[256])
1795 ).tolist()
1796 else: # INT16
1797 table = np.int32(
1798 testGen.rng.integers(low=-32768, high=32768, size=[513])
1799 ).tolist()
1800
1801 arg_list.append(
1802 (
1803 "",
1804 [table],
1805 )
1806 )
1807 return arg_list
1808
1809 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
1810 # CondIf generates the condition values here.
1811 # Convert to tensors in the build function, along with the
1812 # then and else blocks
1813 arg_list = []
1814
1815 for c in [False, True]:
1816 arg_list.append(("cond{}".format(int(c)), [c]))
1817
1818 return arg_list
1819
1820 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
1821 # While loop: 0 iterations, 1, more than 1
1822 arg_list = []
1823
1824 for iter in [0, 1, 4]:
1825 arg_list.append(("iter{}".format(iter), [iter]))
1826
1827 return arg_list