blob: 9209d9c0534b771c678676e38eaa5b5601c7e989 [file] [log] [blame]
Luke Hutton261b7b62023-01-10 14:50:31 +00001# Copyright (c) 2021-2023, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
4import math
James Ward8b390432022-08-12 20:48:56 +01005import warnings
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01006
7import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008from generator.tosa_error_if import ErrorIf
9from generator.tosa_error_if import TosaErrorIfArgGen
James Ward8b390432022-08-12 20:48:56 +010010from generator.tosa_utils import get_accum_dtype_from_tgTypes
11from generator.tosa_utils import get_wrong_output_type
Jeremy Johnsona0e03f32022-06-13 17:48:09 +010012from generator.tosa_utils import MAX_RESIZE_DIMENSION
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010013from serializer.tosa_serializer import DTypeNames
14from tosa.DType import DType
15from tosa.Op import Op
16from tosa.ResizeMode import ResizeMode
17
18# DTypeNames, DType, Op and ResizeMode are convenience variables to the
19# flatc-generated types that should be enums, but aren't
20
21
class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(testGen, dtype, error_name=None):
        """Return a zero point value suitable for dtype.

        INT8/UINT8 get a random zero point in the type's range (or the
        command-line supplied value, clamped); for zero-point ERROR_IF tests
        a non-zero value is forced; all other dtypes return 0.
        """

        if dtype == DType.INT8:
            # Honour a user-supplied zeropoint, clamped to int8 range
            if testGen.args.zeropoint is not None:
                return min(127, max(-128, testGen.args.zeropoint))
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            # Honour a user-supplied zeropoint, clamped to uint8 range
            if testGen.args.zeropoint is not None:
                return min(255, max(0, testGen.args.zeropoint))
            return testGen.randInt(0, 256)
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            # Other dtypes must use zero point 0 - so for the negative test
            # produce any non-zero value to trigger the ERROR_IF check
            zero_point = testGen.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Return [input_zp, output_zp] for a unary op, forcing the relevant
        zero point non-zero for the requested ERROR_IF case."""
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Return [input_zp, weight_zp] for a convolution-style op.

        dtype_or_dtypeList may be a single dtype or an
        [input, weights, accumulator] list.
        """
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Return [a_zp, b_zp] for MATMUL; both forced non-zero for the
        InputZeroPointNotZero ERROR_IF case."""
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift
        """Convert a float scale into a fixed-point (multiplier, shift) pair.

        scale32 selects a 32-bit (31 fractional bits) or 16-bit
        (15 fractional bits) multiplier encoding. Returns (multiplier, shift)
        with shift adjusted into the range [2, 62].
        """

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        # frexp gives mantissa m in [0.5, 1) (negative for negative input)
        # and the binary exponent
        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            # make the mantissa positive; sign is handled by the caller
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        # Rounding can push the mantissa to exactly 1.0 - renormalize
        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift
154
155
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Create one random shape and replicate it for every operand."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Create a random rank-4 (NHWC) shape and replicate it per operand."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Create [values_in, input] rank-3 shapes sharing batch/channel dims."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = min(values_in_shape[0], testGen.args.max_batch_size)

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Create per-operand shapes where one randomly chosen input has a
        random dimension set to 1 so it broadcasts (or is deliberately
        mismatched for ERROR_IF tests)."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                if error_name == ErrorIf.DimensionMismatch:
                    shape_bcast[fuzz_idx] += 1
                elif error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (OHWI), bias (OC)] shapes for CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Create [ifm (NDHWC), filter (ODHWI), bias (OC)] shapes for CONV3D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension()

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (OHWI), bias (OC)] shapes for
        TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (HWCM), bias (C*M)] shapes for
        DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        """Create two identical NHW shapes with power-of-two H/W for FFT2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        """Create a single NHW shape with power-of-two H/W for RFFT2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Create [input, filter, bias] rank-2 shapes for FULLY_CONNECTED."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Create compatible [a, b] rank-3 shapes for MATMUL."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Create a random number of identical shapes to concatenate
        (deliberately rank-mismatched for the ERROR_IF case)."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Re-split the concat shapes along axis so several const inputs can
        be used without creating too many large tensors."""
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
601
602
603class TosaTensorValuesGen:
604 """Tensor Value generators create the random data for each test."""
605
    def __init__(self):
        # Stateless helper class - all value generators are static methods
        pass
608
609 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000610 def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100611 pCount, cCount = op["operands"]
612
613 tens = []
614 tens.extend(
615 testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
616 )
617 tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
618
619 return tens
620
    @staticmethod
    def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """Create input data for NEGATE.

        For INT32 positive tests the random values are constrained so their
        negation stays representable in int32 (i.e. -(1<<31) is excluded);
        all other cases defer to the default generator.
        """
        if dtypeList[0] == DType.INT32 and error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.NEGATE must have 1 placeholders, 0 consts"
            # Must create tensors with values within accumulator (int32) negatable
            # range
            max_val = (1 << 31) - 1
            min_val = -max_val
            arr = np.int32(
                testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
644
    @staticmethod
    def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """Create input data for ADD/SUB.

        For INT32 positive tests the second operand is clipped so that the
        int32 result cannot saturate; other cases use the default generator.
        """
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Make sure the operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            placeholders = []
            add = op["op"] == Op.ADD
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            # Compute the exact (int64) result to detect overflow
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
713
    @staticmethod
    def tvgCondIfWhileLoop(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        """Create input data for COND_IF / WHILE_LOOP control-flow tests.

        Integer inputs are range-limited (int16 data for int32 tensors,
        0..31 for int8/int16) to avoid saturation/illegal shifts inside the
        loop body; other dtypes use the default generator.
        """
        if dtypeList[0] in (
            DType.INT32,
            DType.INT16,
            DType.INT8,
        ):
            # Limit input tensors with cond_if_binary or while_loop to stop
            # saturation of add/sub ops with int32 and keep all logical shift
            # values between 0 to 31 for int16 or int8
            pCount, cCount = op["operands"]
            pRemain = pCount
            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[0] == DType.INT32:
                    arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
                else:
                    arr = np.int32(
                        testGen.rng.integers(low=0, high=32, size=shapeList[idx])
                    )
                # First pCount operands are placeholders, the rest are consts
                if pRemain > 0:
                    placeholders.append(
                        testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
                    )
                    pRemain -= 1
                else:
                    placeholders.append(
                        testGen.ser.addConst(shape, dtypeList[idx], arr)
                    )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
751
752 @staticmethod
753 def tvgArithmeticRightShift(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000754 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100755 ):
756 pCount, cCount = op["operands"]
757 # Force value of operand[1] to be within [0, num_bits]
758 assert (
759 pCount == 2 and cCount == 0
760 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
761
762 placeholders = []
763 for idx, shape in enumerate(shapeList[:]):
764 if idx == 1:
765 if dtypeList[idx] == DType.INT8:
766 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
767 elif dtypeList[idx] == DType.INT16:
768 arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
769 elif dtypeList[idx] == DType.INT32:
770 arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
771 elif error_name == ErrorIf.WrongInputType:
772 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
773 else:
774 raise Exception("OpArithmeticRightShift: invalid input dtype")
775 else:
776 arr = testGen.getRandTensor(shape, dtypeList[idx])
777 placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
778
779 return placeholders
780
    @staticmethod
    def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """SELECT: force the first (condition) operand to be boolean, then
        use the default value generator."""
        # Set datatype of condition tensor to boolean
        dtypeList[0] = DType.BOOL

        return TosaTensorValuesGen.tvgDefault(
            testGen, op, dtypeList, shapeList, testArgs, error_name
        )
789
790 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000791 def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100792 if error_name is None:
793 pCount, cCount = op["operands"]
794 assert (
795 pCount == 2 and cCount == 0
796 ), "Op.INTDIV must have 2 placeholders, 0 consts"
797
798 placeholders = []
799
800 # Two invalid cases for Op.INTDIV:
801 # 1. divisor == 0
802 # 2. dividend == -(1<<31) and divisor == -1
803 while True:
804 dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
805 divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
806
807 if (divisor_arr == 0).any():
808 continue
809
810 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
811 continue
812
813 break
814
815 placeholders.append(
816 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
817 )
818 placeholders.append(
819 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
820 )
821
822 return placeholders
823 else:
824 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000825 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100826 )
827
    @staticmethod
    def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """Create input data for MUL.

        Floating-point types use unconstrained random data; integer types
        have their operands repeatedly halved until the (shifted) product
        fits in the int32 accumulator range.
        """
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            tens = []
            if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32):
                tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure multiply result in int32 range
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                elif error_name == ErrorIf.WrongInputType:
                    num_bits = 8
                else:
                    raise Exception("OpMul: invalid input dtype")

                # NOTE(review): this loop regenerates both operands on every
                # iteration; only the arrays from the final pass are used -
                # looks like a leftover from an older per-operand loop. Verify
                # against upstream history before simplifying.
                for idx, shape in enumerate(shapeList[:]):
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                    a_arr = np.int32(
                        testGen.rng.integers(low=low, high=high, size=shapeList[0])
                    )
                    b_arr = np.int32(
                        testGen.rng.integers(low=low, high=high, size=shapeList[1])
                    )

                i = 0
                while True:

                    a_arr_64 = a_arr.astype(np.int64)
                    b_arr_64 = b_arr.astype(np.int64)

                    # Apply the same rounded shift the operator will perform
                    if shift > 0:
                        rounding = 1 << (shift - 1)
                        result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                    else:
                        result_arr = a_arr_64 * b_arr_64

                    if (result_arr > -(2**31)).all() and (
                        result_arr <= ((2**31) - 1)
                    ).all():
                        break

                    # Result out of int32 range - halve both operands and retry
                    i = i + 1
                    a_arr = a_arr // 2
                    b_arr = b_arr // 2

                placeholders.append(
                    testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)

            return tens
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
901
902 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000903 def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100904 count = len(shapeList) - testGen.args.num_const_inputs_concat
905 if count < 1:
906 count = 1
907 if testGen.args.num_const_inputs_concat == 0:
908 count = len(shapeList)
909
910 # Ensure axis is an int
911 testArgs[0] = int(testArgs[0])
912
913 shapeList = TosaTensorGen.tgConcatConstInput(
914 testGen, shapeList, testArgs[0], error_name
915 )
916
917 tens = []
918 tens.extend(
919 testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
920 )
921 tens.extend(testGen.buildConstTensors(shapeList[count:], dtypeList[count:]))
922
923 return tens
924
925 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000926 def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100927 pCount, cCount = op["operands"]
928 assert (
929 pCount == 2 and cCount == 0
930 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
931 values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
932 shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
933 placeholders = []
934 placeholders.append(
935 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
936 )
937 placeholders.append(
938 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
939 )
940
941 return placeholders
942
943 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000944 def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100945 if error_name is None:
946 pCount, cCount = op["operands"]
947 assert (
948 pCount == 2 and cCount == 0
949 ), "Op.EQUAL must have 2 placeholders, 0 consts"
950 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
951 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
952 # Using random numbers means that it will be very unlikely that
953 # there are any matching (equal) values, therefore force that
954 # there are twice the number of matching values as the tensor rank
955 for num in range(0, len(shapeList[0]) * 2):
956 a_index = []
957 b_index = []
958 # Choose an index in each axis for the whole shape
959 for axis in range(0, len(shapeList[0])):
960 # Index can be up to the largest dimension in both shapes
961 index = np.int32(
962 testGen.rng.integers(
963 0, max(shapeList[0][axis], shapeList[1][axis])
964 )
965 )
966 # Reduce the index down to a shape's dim for broadcasting
967 a_index.append(min(shapeList[0][axis] - 1, index))
968 b_index.append(min(shapeList[1][axis] - 1, index))
969
970 a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
971
972 placeholders = []
973 placeholders.append(
974 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
975 )
976 placeholders.append(
977 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
978 )
979 return placeholders
980 else:
981 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000982 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100983 )
984
985 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000986 def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100987 if dtypeList[0] == DType.INT32:
988 pCount, cCount = op["operands"]
989 assert (
990 pCount == 1 and cCount == 0
991 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
992 # Limit values so that the sum cannot exceed the range of an int32 during
993 # summation of any axis
994 range_val = int((1 << 31) / max(shapeList[0]))
995 values_arr = np.int32(
996 testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
997 )
998 placeholders = []
999 placeholders.append(
1000 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1001 )
1002 return placeholders
1003 else:
1004 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001005 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001006 )
1007
1008
1009class TosaArgGen:
1010 """Argument generators create exhaustive or random lists of attributes for
1011 operators that take attributes or other parameters.
1012
1013 The return value is a list of (descriptive_name, [arglist]) tuples where
1014 the descriptive_name is appended to the test name and the arglist is expanded
1015 as arguments to the operator build function.
1016 """
1017
1018 def __init__(self):
1019 pass
1020
1021 @staticmethod
1022 def agNone(testGen, opName, shapeList, dtype, error_name=None):
1023 """A trivial argument generator for operators that don't take any
1024 non-tensor arguments"""
1025 return [("", [])]
1026
1027 @staticmethod
1028 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
1029 """Build the axis argument for operators that take a single axis"""
1030 axes = []
1031 shape = shapeList[0]
1032
1033 if error_name == ErrorIf.AxisSmallerZero:
1034 small_axis = testGen.rng.integers(-5, 0)
1035 axes.append(("axis{}".format(small_axis), [small_axis]))
1036 elif error_name == ErrorIf.AxisLargerRank:
1037 large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
1038 axes.append(("axis{}".format(large_axis), [large_axis]))
1039 else:
1040 for a in range(0, len(shape)):
1041 axes.append(("axis{}".format(a), [a]))
1042
1043 return axes
1044
    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        """Build (stride, pad, dilation) argument combinations for CONV ops.

        Returns a list of (name_suffix, [accum_dtype, stride, pad, dilation])
        tuples sparsely sampled from the cross product of legal values; named
        ErrorIf cases substitute specific invalid value(s) instead.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]
        # determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3])
        k = [int(x) for x in opName.split("_")[-1].split("x")]

        accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        # Check the rank
        rank = 5 if opName.startswith("conv3d") else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits batch and channels
        k_rank = rank - 2
        assert len(k) == k_rank

        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        if error_name == ErrorIf.PadSmallerZero:
            p_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        # Two pad values (before/after) per spatial dimension
        paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            # Stride must be greater than 1 to force non-integer error
            startStride = 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
            s_vals = [x for x in range(startStride, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * k_rank))}
        if error_name == ErrorIf.DilationSmallerOne:
            d_vals = [testGen.rng.choice(range(-5, 1))]
        else:
            d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

        if not error_name and testGen.args.oversize:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    {x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))}
                )
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
            bigDilation = 7
            dilations.update(
                {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
            )

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 120
        sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1

        # n counts every (s, p, d) combination considered (including skipped
        # ones) so that every sparsity-th candidate is kept
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a positive
                        # sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k[1] - 1)
                        and (
                            k_rank < 3
                            or ((ifm_shape[3] - 1 + p[4] + p[5]) > d[2] * (k[2] - 1))
                        )
                    ):
                        # Per-dimension remainder of the output-shape division;
                        # zero means the parameters give an integer-exact output
                        remainders = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            remainders.append(
                                (
                                    ifm_shape[index + 1]
                                    - 1
                                    + p[pad_offset]
                                    + p[pad_offset + 1]
                                    - (k[index] - 1) * d[index]
                                )
                                % s[index]
                            )
                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        "".join([str(x) for x in s]),
                                        "".join([str(x) for x in p]),
                                        "".join([str(x) for x in d]),
                                    ),
                                    [accum_dtype, s, p, d],
                                )
                            )
                    n += 1

        return arg_list
1163
1164 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01001165 def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
1166
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001167 assert isinstance(dtypes, list) or isinstance(
1168 dtypes, tuple
1169 ), f"{dtypes} unexpected"
1170 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01001171
1172 if error_name == ErrorIf.WrongOutputType:
1173 accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype)
1174 elif error_name == ErrorIf.WrongInputType:
1175 # Pick some potentially correct output dtype if input type is incorrect
1176 accum_dtype = DType.INT32
1177 else:
1178 accum_dtype = get_accum_dtype_from_tgTypes(dtypes)
1179
1180 return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])]
1181
    @staticmethod
    def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
        """Generate the accumulator dtype argument(s) for MATMUL.

        One test is produced per legal accumulator dtype for the I/O dtype;
        the Wrong*Type ErrorIf cases override the list with a single entry.
        NOTE(review): if dtype is invalid and error_name is set but is not
        one of the two handled ErrorIfs, accum_dtypes is left unbound and the
        final line raises NameError - presumed unreachable in practice.
        """
        # Get valid accumulate type(s)
        if dtype == DType.INT8:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.INT16:
            accum_dtypes = [DType.INT48]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FP32]
        elif dtype == DType.BF16:
            accum_dtypes = [DType.FP32]
        elif dtype == DType.FP32:
            accum_dtypes = [DType.FP32]
        elif error_name is None:
            assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"

        if error_name == ErrorIf.WrongOutputType:
            # Get incorrect output dtype for ErrorIf case
            accum_dtypes = [get_wrong_output_type(opName, testGen.rng, dtype)]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            accum_dtypes = [DType.INT32]

        return [(f"acc{testGen.typeStr(a)}", [a]) for a in accum_dtypes]
1206
    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
        """Build (stride, pad, output_shape) argument combinations for
        TRANSPOSE_CONV2D.

        Returns (name_suffix, [accum_dtype, stride, pad, output_shape])
        tuples sparsely sampled from the legal cross product; named ErrorIf
        cases substitute specific invalid value(s) instead.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        # Must be rank 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == 4
            assert len(filter_shape) == 4

        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        # Most negative padding still allowed: -(kernel size) + 1
        smallest_padding_size = -min(filter_shape[1], filter_shape[2]) + 1
        if error_name == ErrorIf.PadLargerEqualKernel:
            max_filter_size = -max(filter_shape[1], filter_shape[2])
            p_vals = [testGen.rng.choice(range(max_filter_size - 10, max_filter_size))]
        else:
            p_vals = [
                x
                for x in range(smallest_padding_size, testGen.args.max_conv_padding + 1)
            ]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}

        if not error_name and testGen.args.oversize:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    {
                        x
                        for x in itertools.product(
                            *([[smallest_padding_size, bigPadding]] * 4)
                        )
                    }
                )
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 10
        sparsity = len(paddings) * len(strides) // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1

        # n counts every (s, p) combination considered so that every
        # sparsity-th candidate is kept
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                if n % sparsity == 0:
                    # Determine the output shape
                    oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + filter_shape[1]
                    ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + filter_shape[2]
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]
                    arg_list.append(
                        (
                            "acc{}_st{}_pad{}_os{}".format(
                                testGen.typeStr(accum_dtype),
                                "".join([str(x) for x in s]),
                                "".join([str(x) for x in p]),
                                "x".join([str(x) for x in os]),
                            ),
                            [accum_dtype, s, p, os],
                        )
                    )
                n += 1

        return arg_list
1289
    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        """Build padding argument combinations for PAD.

        Returns (name_suffix, [padding_array, pad_const_int, pad_const_fp])
        tuples covering every before/after padding combination per dimension.
        Unsupported dtypes yield an empty list.
        """
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        # Only one of the two pad constants is randomised, matching the dtype
        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            return []

        for paddings in shape_pad_values:
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        # Neutralise this axis's padding to keep the output valid
                        paddings[i] = (0, 0)
                    # If any axis ends up with no negative padding, the case no
                    # longer exercises PadSmallerZero - drop it
                    if all([p > -1 for p in paddings[i]]):
                        args_valid = False

            if args_valid:
                name = "pad"
                for r in range(rank):
                    before, after = paddings[r]
                    name = f"{name}{before}{after}"
                arg_list.append(
                    (name, [np.array(paddings), pad_const_int, pad_const_fp])
                )

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

        return arg_list
1342
    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        """Build (accum, stride, pad, kernel) argument combinations for
        pooling operators.

        max_pool2d has no accumulator dtype, so its name suffix and value
        list omit the accumulator entry; avg_pool entries include it.
        """
        arg_list = []

        shape = shapeList[0]
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        # Stride must be greater than 1 to force non-integer error
        startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
        s_vals = [x for x in range(startStride, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        if opName == "max_pool2d":
            accum_dtypes = [None]  # max_pool has no accumulate dtype
        elif dtype == DType.INT8 or dtype == DType.INT16:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FP32]
        elif dtype == DType.BF16 or dtype == DType.FP32:
            accum_dtypes = [DType.FP32]
        elif error_name is None:
            assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
        else:
            # Set to something for the ErrorIf case which has
            # incorrect input data-type
            accum_dtypes = [DType.INT32]

        if testGen.args.oversize:
            # add some oversize argument values
            bigStride = 7
            strides.update(
                {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
            )
            bigKernel = 9
            kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
            if max(shape) < 64:
                # padding must be less than the kernel size
                bigPadding = bigKernel - 1
                paddings.update(
                    {x for x in itertools.product(*([[0, bigPadding]] * 4))}
                )

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 500
        sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1

        arg_str = (
            "acc{}_st{}_kern{}_pad{}"
            if accum_dtypes[0] is not None
            else "st{}_kern{}_pad{}"
        )

        def get_arg_list_element(accum, stride, pad, kern):
            # Return tuple containing the formatted argument string and
            # the corresponding argument values
            arg_str_elems = [
                "".join([str(x) for x in stride]),
                "".join([str(x) for x in kern]),
                "".join([str(x) for x in pad]),
            ]
            # Note: different order to string
            arg_val_elems = [stride, pad, kern]

            if accum is not None:
                arg_str_elems.insert(0, testGen.typeStr(accum))
                arg_val_elems.insert(0, accum)
            return (arg_str.format(*arg_str_elems), arg_val_elems)

        # n counts every (s, p, k) combination considered so that every
        # sparsity-th candidate is kept
        n = 0
        for a in accum_dtypes:
            for s in sorted(list(strides)):
                for p in sorted(list(paddings)):
                    for k in sorted(list(kernels)):
                        if error_name in [
                            ErrorIf.StrideSmallerOne,
                            ErrorIf.KernelSmallerOne,
                            ErrorIf.PadSmallerZero,
                            ErrorIf.PadLargerEqualKernel,
                        ]:
                            # Rewrite the candidate into an invalid one for
                            # the requested ErrorIf (None values mean "skip")
                            sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
                                testGen, error_name, s, p, k
                            )
                            if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                                arg_vals = [a, sNew, pNew, kNew]
                                arg_list.append(get_arg_list_element(*arg_vals))
                        elif (
                            n % sparsity == 0
                            # padding must not exceed the kernel size
                            and p[0] < k[0]
                            and p[1] < k[0]
                            and p[2] < k[1]
                            and p[3] < k[1]
                            # the padded shape must exceed the kernel size
                            and (shape[1] + p[0] + p[1]) > k[0]
                            and (shape[2] + p[2] + p[3]) > k[1]
                        ):
                            remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0]
                            remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1]
                            if (
                                # the parameters must produce integer exact output
                                error_name != ErrorIf.PoolingOutputShapeNonInteger
                                and remainder_h == 0
                                and remainder_w == 0
                            ) or (
                                error_name == ErrorIf.PoolingOutputShapeNonInteger
                                and (remainder_h != 0 or remainder_w != 0)
                            ):
                                arg_vals = [a, s, p, k]
                                arg_list.append(get_arg_list_element(*arg_vals))
                        n += 1

        return arg_list
1462
1463 @staticmethod
1464 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
1465 arg_list = []
1466
1467 # Enumerate the output types here
1468 if error_name == ErrorIf.WrongOutputType:
1469 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
1470 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00001471 dtypeList = [
1472 DType.BOOL,
1473 DType.INT16,
1474 DType.INT32,
1475 DType.FP16,
1476 DType.BF16,
1477 DType.FP32,
1478 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001479 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00001480 dtypeList = [
1481 DType.BOOL,
1482 DType.INT8,
1483 DType.INT32,
1484 DType.FP16,
1485 DType.BF16,
1486 DType.FP32,
1487 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001488 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00001489 dtypeList = [
1490 DType.BOOL,
1491 DType.INT8,
1492 DType.INT16,
1493 DType.FP16,
1494 DType.BF16,
1495 DType.FP32,
1496 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001497 elif inDtype == DType.BOOL:
1498 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01001499 elif inDtype == DType.FP16:
James Ward736fd1a2023-01-23 17:13:37 +00001500 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01001501 elif inDtype == DType.BF16:
James Ward736fd1a2023-01-23 17:13:37 +00001502 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001503 elif inDtype == DType.FP32:
James Ward736fd1a2023-01-23 17:13:37 +00001504 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001505 elif error_name == ErrorIf.WrongInputType:
1506 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001507 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001508 else:
1509 raise Exception("Unexpected input dtype: {}".format(inDtype))
1510
1511 for dtype in dtypeList:
Jeremy Johnson3b0544c2022-10-18 16:32:19 +01001512 arg_list.append(("out{}".format(testGen.typeStr(dtype)), [dtype]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001513
1514 return arg_list
1515
    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        """Build RESCALE argument combinations.

        Enumerates (output dtype, scale32, double_round, per_channel) tuples,
        skipping combinations that the specification forbids, except where a
        named ErrorIf case requires exactly that invalid combination.
        """
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            # Zero-point ErrorIfs only make sense for dtypes with a zero point
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                # Keep only dtype pairings that actually exercise WrongOutputType
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition.  Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [outDtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list
1614
1615 @staticmethod
1616 def agMul(testGen, opName, shapeList, dtype, error_name=None):
1617 arg_list = []
1618
1619 if dtype is DType.INT32:
1620 for p in range(testGen.args.num_rand_permutations):
1621
1622 shift = testGen.randInt(0, 32)
1623
1624 arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
1625 else:
1626 arg_list.append(("perm0_shift0", [0]))
1627
1628 return arg_list
1629
1630 @staticmethod
1631 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
1632 arg_list = []
1633
1634 arg_list.append(("roundTrue", [True]))
1635 arg_list.append(("roundFalse", [False]))
1636
1637 return arg_list
1638
Luke Hutton57287132023-02-06 14:54:18 +00001639 @staticmethod
1640 def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
1641 arg_list = []
1642
1643 arg_list.append(("inverseTrue", [True]))
1644 arg_list.append(("inverseFalse", [False]))
1645
1646 return arg_list
1647
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001648 # Helper function for reshape. Gets some factors of a larger number.
1649 @staticmethod
1650 def getFactors(val, start=1):
1651 factors = []
1652
1653 for i in range(start, int(np.sqrt(val)) + 1):
1654 if (val % i) == 0:
1655 factors.append(i)
1656
1657 return factors
1658
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        """Generate new shapes for RESHAPE.

        Each candidate shape has a random rank in [1, 6], is built from
        random factors of the input's element count (so the element count is
        preserved), and is discarded if it duplicates an earlier candidate.
        """
        arg_list = []

        origShape = shapeList[0]

        # Total element count that every new shape must preserve
        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast.  Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # Last dimension absorbs whatever is left, keeping the product
                # equal to totalElements
                newShape.append(remainingElements)

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list
1709
1710 @staticmethod
1711 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
1712 arg_list = []
1713
1714 ifm_shape = shapeList[0]
1715
1716 if error_name == ErrorIf.IndexOutsideBounds:
1717 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
1718 incorrect_small_index = range(-len(ifm_shape), 0)
1719 permutations = [p for p in itertools.permutations(incorrect_large_index)]
1720 permutations.extend(
1721 [p for p in itertools.permutations(incorrect_small_index)]
1722 )
1723 elif error_name == ErrorIf.IndexUsedTwice:
1724 # Create list with a duplicated index
1725 perm_range = list(range(len(ifm_shape)))
1726 index_choice = testGen.rng.choice(range(len(perm_range)))
1727 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
1728 permutations = [p for p in itertools.permutations(perm_range)]
1729
1730 else:
1731 # Get all permutations
1732 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
1733
1734 # Limit to possible permutations from shape dimension or argument setting
1735 limit = min(len(permutations), testGen.args.num_rand_permutations)
1736
1737 # Get random permutation generator that uses all permutations
1738 random_permutations = testGen.rng.permutation(permutations)
1739
1740 # Create list of required amount of permutations
1741 arg_list = [
1742 ("perm{}".format(p), [random_permutations[p].tolist()])
1743 for p in range(limit)
1744 ]
1745 return arg_list
1746
1747 @staticmethod
1748 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
1749 arg_list = []
1750
1751 ifm_shape = shapeList[0]
1752 rank = len(ifm_shape)
1753
1754 for p in range(testGen.args.num_rand_permutations):
1755 start = []
1756 size = []
1757
1758 valid = True
1759
1760 for i in range(rank):
1761 if ifm_shape[i] > 1:
1762 start.append(testGen.randInt(0, ifm_shape[i]))
1763 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
1764
1765 # Invalid slice size?
1766 if size[i] == 0:
1767 valid = False
1768 else:
1769 start.append(0)
1770 size.append(1)
1771
1772 if valid:
1773 # If ERROR_IF test required then incorrect start, size will be returned
1774 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
1775 testGen, error_name, ifm_shape, start, size
1776 )
1777 arg_list.append(("perm{}".format(p), [start, size]))
1778 return arg_list
1779
1780 @staticmethod
1781 def agTile(testGen, opName, shapeList, dtype, error_name=None):
1782 arg_list = []
1783
1784 ifm_shape = shapeList[0]
1785 rank = len(ifm_shape)
1786
1787 for p in range(testGen.args.num_rand_permutations):
1788
1789 # Pick a few random, but small multiple values
1790 # because otherwise this has a tendency to generate
1791 # enormous tensors
1792 multiples = []
1793 for i in range(rank):
1794 if ifm_shape[i] > 1000:
1795 # Multiple of 1 if ifm_shape dimension is large to reduce
1796 # tensor size
1797 multiples.append(1)
1798 elif max(ifm_shape) > 1000:
1799 multiples.append(2)
1800 else:
1801 multiples.append(testGen.randInt(1, 4))
1802 arg_list.append(("perm{}".format(p), [multiples]))
1803
1804 return arg_list
1805
1806 @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Generate argument permutations for the RESIZE operator.

        For each legal {resize mode, output dtype} combination, produces up
        to testGen.args.num_rand_permutations unique parameter sets of
        (mode, scale, offset, border, input dtype, output dtype).  Each set
        is drawn from one of three randomised strategies: fully random
        params, upscale/downscale params, or common-aspect-ratio params.
        When error_name is set, the params are rewritten by
        TosaErrorIfArgGen.eiResizeErrorIf to trigger that ERROR_IF case.

        NOTE(review): assumes shapeList[0] is NHWC, i.e. ifm_shape[1] is
        height and ifm_shape[2] is width — confirm against callers.
        """
        arg_list = []
        ifm_shape = shapeList[0]

        def get_aspect_ratio_resize_params():
            # Scale by a common aspect ratio (optionally inverted), with a
            # random letterbox (y) or pillarbox (x) border.
            common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
            aspect_ratio = testGen.rng.choice(common_aspect_ratios)
            invert = testGen.rng.choice((False, True))
            letterbox = testGen.rng.choice((False, True))

            scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
            scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
            scale_y_d = scale_x_d = 1
            offset_x = offset_y = 0

            if letterbox:
                max_border = scale_y_n
                border_y = testGen.randInt(low=0, high=max_border)
                border_x = 0
            else:
                # Pillarboxing
                border_y = 0
                max_border = scale_x_n
                border_x = testGen.randInt(low=0, high=max_border)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)

            return scale, offset, border

        def get_upscale_downscale_params():
            # Either a power-of-two upscale or an integer downscale; loops
            # until a valid parameter set is produced.
            valid_params = False
            while not valid_params:
                upscale = testGen.rng.choice((False, True))

                # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
                origin_sampling = testGen.rng.choice((False, True))

                if upscale:
                    shift = testGen.randInt(low=1, high=4)
                    scale_x_d = scale_y_d = 1
                    scale_x_n = scale_y_n = (
                        1 << shift if origin_sampling else 2 << shift
                    )
                    border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                    offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
                else:
                    scale_x_n = 1
                    scale_y_n = 1

                    # Return list of valid scale_*_d values (max value 4) given input dim shape
                    def get_valid_denom(ifm_dim):
                        # ifm_dim % x == 1 means (ifm_dim - 1) is divisible
                        # by x, keeping the output dimension an integer.
                        return [x for x in range(1, 5) if ifm_dim % x == 1]

                    # Generate list of valid downscale values and choose one randomly
                    valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                    valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                    if not valid_scale_y_ds and not valid_scale_x_ds:
                        # Bad parameters, skip
                        continue

                    if not valid_scale_y_ds:
                        scale_y_d = 1
                    else:
                        scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                    if not valid_scale_x_ds:
                        scale_x_d = 1
                    else:
                        scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                    border_x = border_y = 0
                    offset_y = testGen.randInt(0, 16 * scale_y_n)
                    offset_x = testGen.randInt(0, 16 * scale_x_n)
                valid_params = True

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_rand_params():
            # Fully random scale/offset/border within spec-style bounds.
            # Scale
            scale_y_n = testGen.randInt(low=1, high=(1 << 11))
            scale_x_n = testGen.randInt(low=1, high=(1 << 11))

            scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

            # Offsets and border within the scale
            offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
            border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.BF16:
                outputDTypeList = [DType.BF16]
            elif dtype == DType.FP32:
                outputDTypeList = [DType.FP32]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                perm = 0
                while perm < testGen.args.num_rand_permutations:
                    # Random choice of type of params we are testing
                    _rnd_param_fn = testGen.rng.choice(
                        (
                            get_rand_params,
                            get_upscale_downscale_params,
                            get_aspect_ratio_resize_params,
                        )
                    )
                    scale, offset, border = _rnd_param_fn()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            perm += 1
                            continue
                    else:
                        # Shrink the denominators until the output dims
                        # divide evenly.
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1

                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1

                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]

                    # Common for all data types
                    if error_name is not None:
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        [
                            mode,
                            scale,
                            offset,
                            border,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)
        return arg_list
2069
2070 @staticmethod
2071 def agTable(testGen, opName, shapeList, dtype, error_name=None):
2072 arg_list = []
2073
2074 if dtype == DType.INT8:
2075 table = np.int32(
2076 testGen.rng.integers(low=-128, high=128, size=[256])
2077 ).tolist()
2078 else: # INT16
2079 table = np.int32(
2080 testGen.rng.integers(low=-32768, high=32768, size=[513])
2081 ).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07002082 # Make sure all slopes are within REQUIRE min/max 16-bit int
2083 for idx in range(len(table) - 1):
2084 slope = table[idx + 1] - table[idx]
2085 # Alter the next table entry to force the slope to be ok
2086 if slope > 32767:
2087 table[idx + 1] -= slope - 32767
2088 if slope < -32768:
2089 table[idx + 1] -= slope + 32768
2090 slope = table[idx + 1] - table[idx]
2091 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002092 arg_list.append(
2093 (
2094 "",
2095 [table],
2096 )
2097 )
2098 return arg_list
2099
2100 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
2101 # CondIf generates the condition values here.
2102 # Convert to tensors in the build function, along with the
2103 # then and else blocks
2104 arg_list = []
2105
2106 for c in [False, True]:
2107 arg_list.append(("cond{}".format(int(c)), [c]))
2108
2109 return arg_list
2110
2111 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
2112 # While loop: 0 iterations, 1, more than 1
2113 arg_list = []
2114
2115 for iter in [0, 1, 4]:
2116 arg_list.append(("iter{}".format(iter), [iter]))
2117
2118 return arg_list