blob: fed91f6f393396e99613ada1e382c88a3c76f8b9 [file] [log] [blame]
Luke Hutton261b7b62023-01-10 14:50:31 +00001# Copyright (c) 2021-2023, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
4import math
James Ward8b390432022-08-12 20:48:56 +01005import warnings
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01006
7import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_error_if import ErrorIf
9from generator.tosa_error_if import TosaErrorIfArgGen
James Ward8b390432022-08-12 20:48:56 +010010from generator.tosa_utils import get_accum_dtype_from_tgTypes
11from generator.tosa_utils import get_wrong_output_type
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010012from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from serializer.tosa_serializer import DTypeNames
14from tosa.DType import DType
15from tosa.Op import Op
16from tosa.ResizeMode import ResizeMode
17
18# DTypeNames, DType, Op and ResizeMode are convenience variables to the
19# flatc-generated types that should be enums, but aren't
20
21
22class TosaQuantGen:
23 """QuantizedInfo random generator helper functions.
24
25 Specify with 'qgen': in the operator defintion.
26 """
27
28 def __init__(self):
29 pass
30
31 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +000032 def getZeroPoint(testGen, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010033
34 if dtype == DType.INT8:
Jeremy Johnson00423432022-09-12 17:27:37 +010035 if testGen.args.zeropoint is not None:
36 return min(127, max(-128, testGen.args.zeropoint))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010037 return testGen.randInt(-128, 128)
38 elif dtype == DType.UINT8:
Jeremy Johnson00423432022-09-12 17:27:37 +010039 if testGen.args.zeropoint is not None:
40 return min(255, max(0, testGen.args.zeropoint))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010041 return testGen.randInt(0, 256)
42 elif error_name in [
43 ErrorIf.InputZeroPointNotZero,
44 ErrorIf.WeightZeroPointNotZero,
45 ErrorIf.OutputZeroPointNotZero,
46 ]:
47 zero_point = testGen.randInt(-128, 128)
48 if zero_point == 0:
49 zero_point = 1
50 return zero_point
51 return 0
52
53 @staticmethod
54 def qgUnary(testGen, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010055 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000056 qinfo = [
57 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
58 TosaQuantGen.getZeroPoint(testGen, dtype),
59 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010060 elif error_name == ErrorIf.OutputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000061 qinfo = [
62 TosaQuantGen.getZeroPoint(testGen, dtype),
63 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
64 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010065 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000066 qinfo = [
67 TosaQuantGen.getZeroPoint(testGen, dtype),
68 TosaQuantGen.getZeroPoint(testGen, dtype),
69 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010070 return qinfo
71
72 @staticmethod
73 def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010074 if isinstance(dtype_or_dtypeList, list):
75 # a list of [input, weights, accumulator] dtypes
76 dtypeList = dtype_or_dtypeList
77 else:
78 # an int, [input, weights, accumulator] dtypes are the same
79 dtypeList = [dtype_or_dtypeList] * 3
80
81 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000082 qinfo = [
83 TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
84 TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
85 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010086 elif error_name == ErrorIf.WeightZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000087 qinfo = [
88 TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
89 TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
90 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010091 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000092 qinfo = [
93 TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
94 TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
95 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010096 return qinfo
97
98 @staticmethod
99 def qgMatmul(testGen, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100100 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000101 qinfo = [
102 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
103 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
104 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100105 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000106 qinfo = [
107 TosaQuantGen.getZeroPoint(testGen, dtype),
108 TosaQuantGen.getZeroPoint(testGen, dtype),
109 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100110 return qinfo
111
112 @staticmethod
113 def computeMultiplierAndShift(scaleFp, scale32):
114 # Derived from computeMultiplierAndShiftTosaScale32
115 # Provide a floating-point scaling factor and the scale32 parameter
116 # to compute the multiplier and shift
117
118 if scale32:
119 scaleBits = 31
120 else:
121 scaleBits = 15
122
123 m, shift = math.frexp(scaleFp)
124
125 if scaleFp < 0.0:
126 m = -m
127
128 multiplier = round(m * (1 << scaleBits))
129 assert multiplier <= (1 << scaleBits)
130
131 if multiplier == (1 << scaleBits):
132 multiplier = multiplier // 2
133 shift = shift + 1
134
135 shift = (-shift) + scaleBits
136 # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
137 # scaleFp, scaleBits, m, multiplier, shift))
138
139 # Adjust multiplier such that shift is in allowed value range.
140 if shift == 0:
141 multiplier = multiplier // 4
142 shift = shift + 2
143 elif shift == 1:
144 multiplier = multiplier // 2
145 shift = shift + 1
146 elif shift == 63:
147 multiplier = multiplier * 2
148 shift = shift - 1
149
150 assert multiplier <= (1 << scaleBits)
151 assert shift >= 2 and shift <= 62
152
153 return multiplier, shift
154
155
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        # No instance state; all generators are static methods.
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Return one random shape per operand, all identical.

        For ErrorIf.RankMismatch the shape is re-generated with a different
        rank between appends, so operands after the first can disagree.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Return identical rank-4 (NHWC) shapes for every operand, with the
        batch dimension optionally capped by testGen.args.max_batch_size."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Return [values_in_shape, input_shape] for scatter/gather style ops.

        The second shape shares dims 0 and 2 with the first and gets an
        independently chosen middle dimension W.
        """
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Return operand shapes where one randomly chosen operand has a
        random dimension set to 1 so it broadcasts against the others (or is
        deliberately mismatched for ERROR_IF tests)."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                if error_name == ErrorIf.DimensionMismatch:
                    shape_bcast[fuzz_idx] += 1
                elif error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        # NOTE(review): slicing to -2 drops TWO dims, giving
                        # rank-1 rather than the original rank; both still
                        # mismatch the first operand, but confirm intent.
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (OHWI), bias (OC)] shapes for CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Return [ifm (NDHWC), filter (ODHWI), bias (OC)] shapes for CONV3D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (OHWI), bias (OC)] for TRANSPOSE_CONV2D.

        Shape construction is identical to tgConv2D.
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (HWCM), bias (C*M)] for DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        """Return [ifm (NHW)] for RFFT2D; H and W are forced to powers of two
        (or deliberately broken for KernelNotPowerOfTwo tests)."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        # NOTE(review): math.log is floating point; for exact powers of two a
        # tiny representation error could round the truncated exponent down
        # one step - confirm the possible input range makes this safe.
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        # Constrict the batch size
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Return [input (N,IC), filter (OC,IC), bias (OC)] for FULLY_CONNECTED."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Return [a (N,H,C), b (N,C,W)] shapes for MATMUL; b's last dimension
        is random but collapsed to 1 when a is already large."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Return shapes for CONCAT, adding a random number of extra inputs;
        for ConcatInputRankMismatch tests non-first inputs get their rank
        changed by one (a dim removed or appended)."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Re-split the first concat shape along axis so multiple const
        inputs can be concatenated without creating many large tensors.

        Returns shapeList unchanged for axis/rank error tests or when the
        axis is too small to split.
        """
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
578
579
580class TosaTensorValuesGen:
581 """Tensor Value generators create the random data for each test."""
582
    def __init__(self):
        # No instance state; all value generators are static methods.
        pass
585
586 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000587 def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100588 pCount, cCount = op["operands"]
589
590 tens = []
591 tens.extend(
592 testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
593 )
594 tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
595
596 return tens
597
598 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000599 def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100600 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100601 pCount, cCount = op["operands"]
602 assert (
603 pCount == 1 and cCount == 0
604 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100605 # Must create tensors with values within accumulator (int32) negatable
606 # range
607 max_val = (1 << 31) - 1
608 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100609 arr = np.int32(
610 testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
611 )
612 placeholders = []
613 placeholders.append(
614 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
615 )
616 return placeholders
617 else:
618 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000619 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100620 )
621
    @staticmethod
    def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """Generate ADD/SUB operands.

        For int32 positive tests the second operand is clipped so the
        elementwise result always fits in int32 (no saturation); all other
        cases fall back to the default generator.
        """
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Make sure the operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            placeholders = []
            add = op["op"] == Op.ADD
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            # Compute the (possibly out-of-range) result in int64 to detect overflow
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                # (amin keeps the safest value for every broadcast position)
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
690
691 @staticmethod
692 def tvgCondIfWhileLoop(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000693 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100694 ):
695 if dtypeList[0] in (
696 DType.INT32,
697 DType.INT16,
698 DType.INT8,
699 ):
700 # Limit input tensors with cond_if_binary or while_loop to stop
701 # saturation of add/sub ops with int32 and keep all logical shift
702 # values between 0 to 31 for int16 or int8
703 pCount, cCount = op["operands"]
704 pRemain = pCount
705 placeholders = []
706 for idx, shape in enumerate(shapeList[:]):
707 if dtypeList[0] == DType.INT32:
708 arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
709 else:
710 arr = np.int32(
711 testGen.rng.integers(low=0, high=32, size=shapeList[idx])
712 )
713 if pRemain > 0:
714 placeholders.append(
715 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
716 )
717 pRemain -= 1
718 else:
719 placeholders.append(
720 testGen.ser.addConst(shape, dtypeList[idx], arr)
721 )
722
723 return placeholders
724 else:
725 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000726 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100727 )
728
729 @staticmethod
730 def tvgArithmeticRightShift(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000731 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100732 ):
733 pCount, cCount = op["operands"]
734 # Force value of operand[1] to be within [0, num_bits]
735 assert (
736 pCount == 2 and cCount == 0
737 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
738
739 placeholders = []
740 for idx, shape in enumerate(shapeList[:]):
741 if idx == 1:
742 if dtypeList[idx] == DType.INT8:
743 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
744 elif dtypeList[idx] == DType.INT16:
745 arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
746 elif dtypeList[idx] == DType.INT32:
747 arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
748 elif error_name == ErrorIf.WrongInputType:
749 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
750 else:
751 raise Exception("OpArithmeticRightShift: invalid input dtype")
752 else:
753 arr = testGen.getRandTensor(shape, dtypeList[idx])
754 placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
755
756 return placeholders
757
758 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000759 def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100760 # Set datatype of condition tensor to boolean
761 dtypeList[0] = DType.BOOL
762
763 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000764 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100765 )
766
767 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000768 def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100769 if error_name is None:
770 pCount, cCount = op["operands"]
771 assert (
772 pCount == 2 and cCount == 0
773 ), "Op.INTDIV must have 2 placeholders, 0 consts"
774
775 placeholders = []
776
777 # Two invalid cases for Op.INTDIV:
778 # 1. divisor == 0
779 # 2. dividend == -(1<<31) and divisor == -1
780 while True:
781 dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
782 divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
783
784 if (divisor_arr == 0).any():
785 continue
786
787 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
788 continue
789
790 break
791
792 placeholders.append(
793 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
794 )
795 placeholders.append(
796 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
797 )
798
799 return placeholders
800 else:
801 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000802 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100803 )
804
805 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000806 def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100807 if error_name is None:
808 pCount, cCount = op["operands"]
809 assert (
810 pCount == 2 and cCount == 0
811 ), "Op.MUL must have 2 placeholders, 0 consts"
812
813 tens = []
James Ward24dbc422022-10-19 12:20:31 +0100814 if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100815 tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
816 else:
817 placeholders = []
818
819 # Make sure multiply result in int32 range
820 shift = testArgs[0]
821 if dtypeList[0] == DType.INT8:
822 num_bits = 8
823 elif dtypeList[0] == DType.INT16:
824 num_bits = 16
825 elif dtypeList[0] == DType.INT32:
826 num_bits = 32
827 elif error_name == ErrorIf.WrongInputType:
828 num_bits = 8
829 else:
830 raise Exception("OpMul: invalid input dtype")
831
832 for idx, shape in enumerate(shapeList[:]):
833 low = -(2 ** (num_bits - 1))
834 high = (2 ** (num_bits - 1)) - 1
835
836 a_arr = np.int32(
837 testGen.rng.integers(low=low, high=high, size=shapeList[0])
838 )
839 b_arr = np.int32(
840 testGen.rng.integers(low=low, high=high, size=shapeList[1])
841 )
842
843 i = 0
844 while True:
845
846 a_arr_64 = a_arr.astype(np.int64)
847 b_arr_64 = b_arr.astype(np.int64)
848
849 if shift > 0:
850 rounding = 1 << (shift - 1)
851 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
852 else:
853 result_arr = a_arr_64 * b_arr_64
854
855 if (result_arr > -(2**31)).all() and (
856 result_arr <= ((2**31) - 1)
857 ).all():
858 break
859
860 i = i + 1
861 a_arr = a_arr // 2
862 b_arr = b_arr // 2
863
864 placeholders.append(
865 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
866 )
867 placeholders.append(
868 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
869 )
870
871 tens.extend(placeholders)
872
873 return tens
874 else:
875 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000876 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100877 )
878
879 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000880 def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100881 count = len(shapeList) - testGen.args.num_const_inputs_concat
882 if count < 1:
883 count = 1
884 if testGen.args.num_const_inputs_concat == 0:
885 count = len(shapeList)
886
887 # Ensure axis is an int
888 testArgs[0] = int(testArgs[0])
889
890 shapeList = TosaTensorGen.tgConcatConstInput(
891 testGen, shapeList, testArgs[0], error_name
892 )
893
894 tens = []
895 tens.extend(
896 testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
897 )
898 tens.extend(testGen.buildConstTensors(shapeList[count:], dtypeList[count:]))
899
900 return tens
901
902 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000903 def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100904 pCount, cCount = op["operands"]
905 assert (
906 pCount == 2 and cCount == 0
907 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
908 values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
909 shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
910 placeholders = []
911 placeholders.append(
912 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
913 )
914 placeholders.append(
915 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
916 )
917
918 return placeholders
919
    @staticmethod
    def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """Generate EQUAL comparison operands, forcing some matching values.

        Random data almost never collides, so 2 x rank positions of 'a' are
        overwritten with the corresponding (broadcast-reduced) values of 'b'.
        Error tests use the default generator.
        """
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        testGen.rng.integers(
                            0, max(shapeList[0][axis], shapeList[1][axis])
                        )
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
961
962 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000963 def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100964 if dtypeList[0] == DType.INT32:
965 pCount, cCount = op["operands"]
966 assert (
967 pCount == 1 and cCount == 0
968 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
969 # Limit values so that the sum cannot exceed the range of an int32 during
970 # summation of any axis
971 range_val = int((1 << 31) / max(shapeList[0]))
972 values_arr = np.int32(
973 testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
974 )
975 placeholders = []
976 placeholders.append(
977 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
978 )
979 return placeholders
980 else:
981 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000982 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100983 )
984
985
986class TosaArgGen:
987 """Argument generators create exhaustive or random lists of attributes for
988 operators that take attributes or other parameters.
989
990 The return value is a list of (descriptive_name, [arglist]) tuples where
991 the descriptive_name is appended to the test name and the arglist is expanded
992 as arguments to the operator build function.
993 """
994
    def __init__(self):
        # No instance state - every generator on this class is a staticmethod.
        pass
997
998 @staticmethod
999 def agNone(testGen, opName, shapeList, dtype, error_name=None):
1000 """A trivial argument generator for operators that don't take any
1001 non-tensor arguments"""
1002 return [("", [])]
1003
1004 @staticmethod
1005 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
1006 """Build the axis argument for operators that take a single axis"""
1007 axes = []
1008 shape = shapeList[0]
1009
1010 if error_name == ErrorIf.AxisSmallerZero:
1011 small_axis = testGen.rng.integers(-5, 0)
1012 axes.append(("axis{}".format(small_axis), [small_axis]))
1013 elif error_name == ErrorIf.AxisLargerRank:
1014 large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
1015 axes.append(("axis{}".format(large_axis), [large_axis]))
1016 else:
1017 for a in range(0, len(shape)):
1018 axes.append(("axis{}".format(a), [a]))
1019
1020 return axes
1021
    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        """Generate (name, [accum_dtype, stride, pad, dilation]) argument
        combinations for the convolution style operators.

        Valid combinations are sampled sparsely via a modulus counter;
        named ERROR_IF tests substitute specific invalid value(s) instead.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]
        # determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3])
        k = [int(x) for x in opName.split("_")[-1].split("x")]

        accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        # Check the rank
        rank = 5 if opName.startswith("conv3d") else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits batch and channels
        k_rank = rank - 2
        assert len(k) == k_rank

        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        if error_name == ErrorIf.PadSmallerZero:
            p_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        # Padding has two entries (before/after) per spatial dimension
        paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            # Stride must be greater than 1 to force non-integer error
            startStride = 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
            s_vals = [x for x in range(startStride, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * k_rank))}
        if error_name == ErrorIf.DilationSmallerOne:
            d_vals = [testGen.rng.choice(range(-5, 1))]
        else:
            d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

        if not error_name and testGen.args.oversize:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    {x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))}
                )
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
            bigDilation = 7
            dilations.update(
                {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
            )

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 120
        sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1

        # Enumerate all combinations, keeping every sparsity'th valid one
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a positive
                        # sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k[1] - 1)
                        and (
                            k_rank < 3
                            or ((ifm_shape[3] - 1 + p[4] + p[5]) > d[2] * (k[2] - 1))
                        )
                    ):
                        # Per-dimension remainder of the output-size division;
                        # zero means the parameters give an integer-exact output
                        remainders = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            remainders.append(
                                (
                                    ifm_shape[index + 1]
                                    - 1
                                    + p[pad_offset]
                                    + p[pad_offset + 1]
                                    - (k[index] - 1) * d[index]
                                )
                                % s[index]
                            )
                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        "".join([str(x) for x in s]),
                                        "".join([str(x) for x in p]),
                                        "".join([str(x) for x in d]),
                                    ),
                                    [accum_dtype, s, p, d],
                                )
                            )
                    n += 1

        return arg_list
1140
1141 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01001142 def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
1143
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001144 assert isinstance(dtypes, list) or isinstance(
1145 dtypes, tuple
1146 ), f"{dtypes} unexpected"
1147 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01001148
1149 if error_name == ErrorIf.WrongOutputType:
1150 accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype)
1151 elif error_name == ErrorIf.WrongInputType:
1152 # Pick some potentially correct output dtype if input type is incorrect
1153 accum_dtype = DType.INT32
1154 else:
1155 accum_dtype = get_accum_dtype_from_tgTypes(dtypes)
1156
1157 return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])]
1158
    @staticmethod
    def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
        """Enumerate the valid accumulator dtype(s) for MATMUL given the
        input dtype, substituting incorrect ones for ERROR_IF cases."""
        # Get valid accumulate type(s)
        # NOTE(review): if dtype matches none of these branches while
        # error_name is set but is not Wrong{Input,Output}Type, accum_dtypes
        # would be unbound below - presumably unreachable; verify callers.
        if dtype == DType.INT8:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.INT16:
            accum_dtypes = [DType.INT48]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FP32]
        elif dtype == DType.BF16:
            accum_dtypes = [DType.FP32]
        elif dtype == DType.FP32:
            accum_dtypes = [DType.FP32]
        elif error_name is None:
            assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"

        if error_name == ErrorIf.WrongOutputType:
            # Get incorrect output dtype for ErrorIf case
            accum_dtypes = [get_wrong_output_type(opName, testGen.rng, dtype)]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            accum_dtypes = [DType.INT32]

        return [(f"acc{testGen.typeStr(a)}", [a]) for a in accum_dtypes]
1183
    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
        """Generate (name, [accum_dtype, stride, pad, output_shape]) argument
        combinations for TRANSPOSE_CONV2D, sampled sparsely.

        Negative padding down to 1-kernel is allowed; named ERROR_IF tests
        substitute specific invalid value(s) instead.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        # Must be rank 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == 4
            assert len(filter_shape) == 4

        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        # Most negative padding still giving a positive sized output
        smallest_padding_size = -min(filter_shape[1], filter_shape[2]) + 1
        if error_name == ErrorIf.PadLargerEqualKernel:
            max_filter_size = -max(filter_shape[1], filter_shape[2])
            p_vals = [testGen.rng.choice(range(max_filter_size - 10, max_filter_size))]
        else:
            p_vals = [
                x
                for x in range(smallest_padding_size, testGen.args.max_conv_padding + 1)
            ]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}

        if not error_name and testGen.args.oversize:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    {
                        x
                        for x in itertools.product(
                            *([[smallest_padding_size, bigPadding]] * 4)
                        )
                    }
                )
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 10
        sparsity = len(paddings) * len(strides) // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                if n % sparsity == 0:
                    # Determine the output shape
                    oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + filter_shape[1]
                    ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + filter_shape[2]
                    # NOTE: local name "os" shadows the stdlib module here
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]
                    arg_list.append(
                        (
                            "acc{}_st{}_pad{}_os{}".format(
                                testGen.typeStr(accum_dtype),
                                "".join([str(x) for x in s]),
                                "".join([str(x) for x in p]),
                                "x".join([str(x) for x in os]),
                            ),
                            [accum_dtype, s, p, os],
                        )
                    )
                n += 1

        return arg_list
1266
    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        """Generate padding (and pad-constant) arguments for PAD.

        Exhaustively combines small before/after padding per dimension;
        PadSmallerZero tests use negative padding values instead.
        """
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        # Pick a random pad constant of the matching kind (int or float)
        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            # Unsupported dtype - no arguments generated
            return []

        for paddings in shape_pad_values:
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                    # If this dimension ended up with no negative padding the
                    # candidate no longer tests PadSmallerZero - drop it
                    if all([p > -1 for p in paddings[i]]):
                        args_valid = False

            if args_valid:
                # Encode the before/after amounts into the test name
                name = "pad"
                for r in range(rank):
                    before, after = paddings[r]
                    name = f"{name}{before}{after}"
                arg_list.append(
                    (name, [np.array(paddings), pad_const_int, pad_const_fp])
                )

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

        return arg_list
1319
    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        """Generate accumulator / stride / pad / kernel argument combinations
        for the pooling operators, sampled sparsely.

        Named ERROR_IF tests substitute specific invalid values via
        TosaErrorIfArgGen.eiPoolingErrorIf.
        """
        arg_list = []

        shape = shapeList[0]
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        # Stride must be greater than 1 to force non-integer error
        startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
        s_vals = [x for x in range(startStride, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        # Determine the accumulator dtype(s) from the operator and input dtype
        if opName == "max_pool2d":
            accum_dtypes = [None]  # max_pool has no accumulate dtype
        elif dtype == DType.INT8 or dtype == DType.INT16:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FP32]
        elif dtype == DType.BF16 or dtype == DType.FP32:
            accum_dtypes = [DType.FP32]
        elif error_name is None:
            assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
        else:
            # Set to something for the ErrorIf case which has
            # incorrect input data-type
            accum_dtypes = [DType.INT32]

        if testGen.args.oversize:
            # add some oversize argument values
            bigStride = 7
            strides.update(
                {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
            )
            bigKernel = 9
            kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
            if max(shape) < 64:
                # padding must be less than the kernel size
                bigPadding = bigKernel - 1
                paddings.update(
                    {x for x in itertools.product(*([[0, bigPadding]] * 4))}
                )

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 500
        sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1

        # max_pool entries have no accumulator prefix in the test name
        arg_str = (
            "acc{}_st{}_kern{}_pad{}"
            if accum_dtypes[0] is not None
            else "st{}_kern{}_pad{}"
        )

        def get_arg_list_element(accum, stride, pad, kern):
            # Return tuple containing the formatted argument string and
            # the corresponding argument values
            arg_str_elems = [
                "".join([str(x) for x in stride]),
                "".join([str(x) for x in kern]),
                "".join([str(x) for x in pad]),
            ]
            # Note: different order to string
            arg_val_elems = [stride, pad, kern]

            if accum is not None:
                arg_str_elems.insert(0, testGen.typeStr(accum))
                arg_val_elems.insert(0, accum)
            return (arg_str.format(*arg_str_elems), arg_val_elems)

        n = 0
        for a in accum_dtypes:
            for s in sorted(list(strides)):
                for p in sorted(list(paddings)):
                    for k in sorted(list(kernels)):
                        if error_name in [
                            ErrorIf.StrideSmallerOne,
                            ErrorIf.KernelSmallerOne,
                            ErrorIf.PadSmallerZero,
                            ErrorIf.PadLargerEqualKernel,
                        ]:
                            sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
                                testGen, error_name, s, p, k
                            )
                            if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                                arg_vals = [a, sNew, pNew, kNew]
                                arg_list.append(get_arg_list_element(*arg_vals))
                        elif (
                            n % sparsity == 0
                            # padding must not exceed the kernel size
                            and p[0] < k[0]
                            and p[1] < k[0]
                            and p[2] < k[1]
                            and p[3] < k[1]
                            # the padded shape must exceed the kernel size
                            and (shape[1] + p[0] + p[1]) > k[0]
                            and (shape[2] + p[2] + p[3]) > k[1]
                        ):
                            # Remainders of the output-size division; zero
                            # means the parameters give integer-exact output
                            remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0]
                            remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1]
                            if (
                                # the parameters must produce integer exact output
                                error_name != ErrorIf.PoolingOutputShapeNonInteger
                                and remainder_h == 0
                                and remainder_w == 0
                            ) or (
                                error_name == ErrorIf.PoolingOutputShapeNonInteger
                                and (remainder_h != 0 or remainder_w != 0)
                            ):
                                arg_vals = [a, s, p, k]
                                arg_list.append(get_arg_list_element(*arg_vals))
                        n += 1

        return arg_list
1439
1440 @staticmethod
1441 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
1442 arg_list = []
1443
1444 # Enumerate the output types here
1445 if error_name == ErrorIf.WrongOutputType:
1446 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
1447 elif inDtype == DType.INT8:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001448 dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001449 elif inDtype == DType.INT16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001450 dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001451 elif inDtype == DType.INT32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001452 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001453 elif inDtype == DType.BOOL:
1454 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01001455 elif inDtype == DType.FP16:
1456 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward24dbc422022-10-19 12:20:31 +01001457 elif inDtype == DType.BF16:
1458 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001459 elif inDtype == DType.FP32:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001460 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
1461 elif error_name == ErrorIf.WrongInputType:
1462 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001463 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001464 else:
1465 raise Exception("Unexpected input dtype: {}".format(inDtype))
1466
1467 for dtype in dtypeList:
Jeremy Johnson3b0544c2022-10-18 16:32:19 +01001468 arg_list.append(("out{}".format(testGen.typeStr(dtype)), [dtype]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001469
1470 return arg_list
1471
    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        """Enumerate output dtype / scale32 / double_round / per_channel
        combinations for RESCALE.

        Combinations that are illegal for the given input dtype are skipped;
        ERROR_IF runs keep only the requested invalid combination(s).
        """
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                # These dtypes may legitimately have a non-zero output zero
                # point, so they cannot be used for this ERROR_IF test
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition.  Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [outDtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list
1570
1571 @staticmethod
1572 def agMul(testGen, opName, shapeList, dtype, error_name=None):
1573 arg_list = []
1574
1575 if dtype is DType.INT32:
1576 for p in range(testGen.args.num_rand_permutations):
1577
1578 shift = testGen.randInt(0, 32)
1579
1580 arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
1581 else:
1582 arg_list.append(("perm0_shift0", [0]))
1583
1584 return arg_list
1585
1586 @staticmethod
1587 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
1588 arg_list = []
1589
1590 arg_list.append(("roundTrue", [True]))
1591 arg_list.append(("roundFalse", [False]))
1592
1593 return arg_list
1594
1595 # Helper function for reshape. Gets some factors of a larger number.
1596 @staticmethod
1597 def getFactors(val, start=1):
1598 factors = []
1599
1600 for i in range(start, int(np.sqrt(val)) + 1):
1601 if (val % i) == 0:
1602 factors.append(i)
1603
1604 return factors
1605
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        """Generate new-shape arguments for RESHAPE by randomly factorising
        the total element count into a randomly chosen rank.

        Duplicate shapes are rejected; an escape counter bounds the search.
        """
        arg_list = []

        origShape = shapeList[0]

        # Total element count must be preserved by every new shape
        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            # Not enough distinct factors to fill the requested rank
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # Last dimension absorbs whatever element count remains
                newShape.append(remainingElements)

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list
1656
    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
        """Generate permutation arguments for TRANSPOSE, limited to at most
        num_rand_permutations randomly ordered entries.

        ERROR_IF cases build out-of-range or duplicated-index permutations.
        """
        arg_list = []

        ifm_shape = shapeList[0]

        if error_name == ErrorIf.IndexOutsideBounds:
            # Indices entirely above the rank, or entirely negative
            incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
            incorrect_small_index = range(-len(ifm_shape), 0)
            permutations = [p for p in itertools.permutations(incorrect_large_index)]
            permutations.extend(
                [p for p in itertools.permutations(incorrect_small_index)]
            )
        elif error_name == ErrorIf.IndexUsedTwice:
            # Create list with a duplicated index
            perm_range = list(range(len(ifm_shape)))
            index_choice = testGen.rng.choice(range(len(perm_range)))
            perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
            permutations = [p for p in itertools.permutations(perm_range)]

        else:
            # Get all permutations
            permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Get random permutation generator that uses all permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create list of required amount of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list
1693
    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype, error_name=None):
        """Generate random (start, size) arguments for SLICE.

        Dimensions of size 1 always use start=0, size=1; candidates where
        any dimension's random slice size is zero are discarded.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            start = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    # Random start, then a random size that fits the remainder
                    start.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - start[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    start.append(0)
                    size.append(1)

            if valid:
                # If ERROR_IF test required then incorrect start, size will be returned
                start, size = TosaErrorIfArgGen.eiSliceErrorIf(
                    testGen, error_name, ifm_shape, start, size
                )
                arg_list.append(("perm{}".format(p), [start, size]))
        return arg_list
1726
1727 @staticmethod
1728 def agTile(testGen, opName, shapeList, dtype, error_name=None):
1729 arg_list = []
1730
1731 ifm_shape = shapeList[0]
1732 rank = len(ifm_shape)
1733
1734 for p in range(testGen.args.num_rand_permutations):
1735
1736 # Pick a few random, but small multiple values
1737 # because otherwise this has a tendency to generate
1738 # enormous tensors
1739 multiples = []
1740 for i in range(rank):
1741 if ifm_shape[i] > 1000:
1742 # Multiple of 1 if ifm_shape dimension is large to reduce
1743 # tensor size
1744 multiples.append(1)
1745 elif max(ifm_shape) > 1000:
1746 multiples.append(2)
1747 else:
1748 multiples.append(testGen.randInt(1, 4))
1749 arg_list.append(("perm{}".format(p), [multiples]))
1750
1751 return arg_list
1752
1753 @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Generate argument permutations for RESIZE tests.

        Each returned argument is
        (name_str, [mode, scale, offset, border, dtype, outputDType]) where
        scale is [scale_y_n, scale_y_d, scale_x_n, scale_x_d], offset is
        [offset_y, offset_x] and border is [border_y, border_x].

        Parameters are drawn from one of three random strategies (fully
        random, upscale/downscale, common aspect ratios), then adjusted so
        the output dimensions are legal integers - unless an ERROR_IF test
        is requested, in which case deliberately invalid values may be
        substituted via TosaErrorIfArgGen.eiResizeErrorIf.
        """
        arg_list = []
        # NOTE(review): indexing below uses ifm_shape[1] as H and
        # ifm_shape[2] as W, i.e. assumes an NHWC input - confirm
        ifm_shape = shapeList[0]

        # Strategy: resize following a common aspect ratio, padded out with
        # either letterbox (y border) or pillarbox (x border)
        def get_aspect_ratio_resize_params():
            common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
            aspect_ratio = testGen.rng.choice(common_aspect_ratios)
            invert = testGen.rng.choice((False, True))
            letterbox = testGen.rng.choice((False, True))

            scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
            scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
            scale_y_d = scale_x_d = 1
            offset_x = offset_y = 0

            if letterbox:
                max_border = scale_y_n
                border_y = testGen.randInt(low=0, high=max_border)
                border_x = 0
            else:
                # Pillarboxing
                border_y = 0
                max_border = scale_x_n
                border_x = testGen.randInt(low=0, high=max_border)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)

            return scale, offset, border

        # Strategy: power-of-two upscale, or integer downscale by 2..4
        def get_upscale_downscale_params():
            valid_params = False
            while not valid_params:
                upscale = testGen.rng.choice((False, True))

                # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
                origin_sampling = testGen.rng.choice((False, True))

                if upscale:
                    shift = testGen.randInt(low=1, high=4)
                    scale_x_d = scale_y_d = 1
                    scale_x_n = scale_y_n = (
                        1 << shift if origin_sampling else 2 << shift
                    )
                    border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                    offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
                else:
                    scale_x_n = 1
                    scale_y_n = 1

                    # Return list of valid scale_*_d values (max value 4) given input dim shape
                    # NOTE(review): ifm_dim % x == 1 means (ifm_dim - 1) is
                    # divisible by x - presumably so the downscaled output
                    # dimension is an integer; confirm against the spec
                    def get_valid_denom(ifm_dim):
                        return [x for x in range(1, 5) if ifm_dim % x == 1]

                    # Generate list of valid downscale values and choose one randomly
                    valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                    valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                    if not valid_scale_y_ds and not valid_scale_x_ds:
                        # Bad parameters, skip
                        continue

                    if not valid_scale_y_ds:
                        scale_y_d = 1
                    else:
                        scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                    if not valid_scale_x_ds:
                        scale_x_d = 1
                    else:
                        scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                    border_x = border_y = 0
                    offset_y = testGen.randInt(0, 16 * scale_y_n)
                    offset_x = testGen.randInt(0, 16 * scale_x_n)
                valid_params = True

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        # Strategy: fully random scale/offset/border values
        def get_rand_params():
            # Scale
            scale_y_n = testGen.randInt(low=1, high=(1 << 11))
            scale_x_n = testGen.randInt(low=1, high=(1 << 11))

            scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

            # Offsets and border within the scale
            offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
            border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.BF16:
                outputDTypeList = [DType.BF16]
            elif dtype == DType.FP32:
                outputDTypeList = [DType.FP32]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                perm = 0
                # NOTE(review): several branches below 'continue' without
                # incrementing perm (duplicates, out-of-scope outputs); if
                # the RNG keeps producing such params this loop can spin
                # for a long time - confirm this is acceptable
                while perm < testGen.args.num_rand_permutations:
                    # Random choice of type of params we are testing
                    _rnd_param_fn = testGen.rng.choice(
                        (
                            get_rand_params,
                            get_upscale_downscale_params,
                            get_aspect_ratio_resize_params,
                        )
                    )
                    scale, offset, border = _rnd_param_fn()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            perm += 1
                            continue
                    else:
                        # Shrink the denominators until the output dims
                        # divide evenly (d == 1 always terminates the loop)
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1

                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1

                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    # Re-pack as lists (the ERROR_IF helper and serializer
                    # receive lists, not tuples)
                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]

                    # Common for all data types
                    if error_name is not None:
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        [
                            mode,
                            scale,
                            offset,
                            border,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)
        return arg_list
2016
2017 @staticmethod
2018 def agTable(testGen, opName, shapeList, dtype, error_name=None):
2019 arg_list = []
2020
2021 if dtype == DType.INT8:
2022 table = np.int32(
2023 testGen.rng.integers(low=-128, high=128, size=[256])
2024 ).tolist()
2025 else: # INT16
2026 table = np.int32(
2027 testGen.rng.integers(low=-32768, high=32768, size=[513])
2028 ).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07002029 # Make sure all slopes are within REQUIRE min/max 16-bit int
2030 for idx in range(len(table) - 1):
2031 slope = table[idx + 1] - table[idx]
2032 # Alter the next table entry to force the slope to be ok
2033 if slope > 32767:
2034 table[idx + 1] -= slope - 32767
2035 if slope < -32768:
2036 table[idx + 1] -= slope + 32768
2037 slope = table[idx + 1] - table[idx]
2038 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002039 arg_list.append(
2040 (
2041 "",
2042 [table],
2043 )
2044 )
2045 return arg_list
2046
2047 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
2048 # CondIf generates the condition values here.
2049 # Convert to tensors in the build function, along with the
2050 # then and else blocks
2051 arg_list = []
2052
2053 for c in [False, True]:
2054 arg_list.append(("cond{}".format(int(c)), [c]))
2055
2056 return arg_list
2057
2058 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
2059 # While loop: 0 iterations, 1, more than 1
2060 arg_list = []
2061
2062 for iter in [0, 1, 4]:
2063 arg_list.append(("iter{}".format(iter), [iter]))
2064
2065 return arg_list