# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
import math
import warnings

import numpy as np
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_utils import get_accum_dtype_from_tgTypes
from generator.tosa_utils import get_wrong_output_type
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode

# DTypeNames, DType, Op and ResizeMode are convenience aliases for the
# flatc-generated types that should be enums, but aren't


class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(testGen, dtype, error_name=None):

        if dtype == DType.INT8:
            if testGen.args.zeropoint is not None:
                return min(127, max(-128, testGen.args.zeropoint))
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            if testGen.args.zeropoint is not None:
                return min(255, max(0, testGen.args.zeropoint))
            return testGen.randInt(0, 256)
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            zero_point = testGen.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

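    # Illustrative behaviour (a sketch, assuming testGen.args.zeropoint is None):
    #   getZeroPoint(testGen, DType.INT8)  -> random value in [-128, 127]
    #   getZeroPoint(testGen, DType.UINT8) -> random value in [0, 255]
    #   any other dtype, no error          -> 0 (zero point must be zero)
    # With one of the ZeroPointNotZero error names, a deliberately non-zero
    # value is returned instead.
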
    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
        #     scaleFp, scaleBits, m, multiplier, shift))

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift
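
    # Worked example (a sketch, scale32=True so scaleBits=31):
    #   scaleFp = 0.5 -> math.frexp gives m=0.5, shift=0
    #   multiplier = round(0.5 * 2**31) = 2**30 = 1073741824
    #   shift = -0 + 31 = 31, which needs no range adjustment
    #   result (1073741824, 31) represents 0.5 as 2**30 / 2**31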


class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                if error_name == ErrorIf.DimensionMismatch:
                    shape_bcast[fuzz_idx] += 1
                elif error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list
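
    # Illustrative outcome (a sketch), rank 3 and no error: with shape
    # [2, 3, 4], bcast_idx=1 and fuzz_idx=1, the result is
    # [[2, 3, 4], [2, 1, 4]] - the second operand broadcasts along axis 1.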

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        # Constrict the batch size
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
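
    # Illustrative split (a sketch), no error: shapeList = [[2, 8, 3]] * 4
    # with axis=1 yields [[2, 8, 3], [2, 4, 3], [2, 2, 3], [2, 2, 3]] - the
    # axis is halved each iteration and the remainder appended at the end,
    # so the const inputs concatenate back to the original shape.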


class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        pCount, cCount = op["operands"]

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        return tens

    @staticmethod
    def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32 and error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.NEGATE must have 1 placeholder, 0 consts"
            # Must create tensors with values within accumulator (int32) negatable
            # range
            max_val = (1 << 31) - 1
            min_val = -max_val
            arr = np.int32(
                testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Make sure the operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            placeholders = []
            add = op["op"] == Op.ADD
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform
                # opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
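
    # Illustrative clipping for ADD (a sketch): if a = 2**31 - 10 and
    # b = 100, the int64 sum exceeds the int32 maximum by sat_max = 91,
    # so b is reduced to 9 and a + b lands exactly on 2**31 - 1.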

    @staticmethod
    def tvgCondIfWhileLoop(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        if dtypeList[0] in (
            DType.INT32,
            DType.INT16,
            DType.INT8,
        ):
            # Limit input tensors with cond_if_binary or while_loop to stop
            # saturation of add/sub ops with int32 and keep all logical shift
            # values between 0 and 31 for int16 or int8
            pCount, cCount = op["operands"]
            pRemain = pCount
            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[0] == DType.INT32:
                    arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
                else:
                    arr = np.int32(
                        testGen.rng.integers(low=0, high=32, size=shapeList[idx])
                    )
                if pRemain > 0:
                    placeholders.append(
                        testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
                    )
                    pRemain -= 1
                else:
                    placeholders.append(
                        testGen.ser.addConst(shape, dtypeList[idx], arr)
                    )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgArithmeticRightShift(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        pCount, cCount = op["operands"]
        # Force value of operand[1] to be within [0, num_bits]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

        placeholders = []
        for idx, shape in enumerate(shapeList[:]):
            if idx == 1:
                if dtypeList[idx] == DType.INT8:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                elif dtypeList[idx] == DType.INT16:
                    arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
                elif dtypeList[idx] == DType.INT32:
                    arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
                elif error_name == ErrorIf.WrongInputType:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                else:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
            else:
                arr = testGen.getRandTensor(shape, dtypeList[idx])
            placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))

        return placeholders

    @staticmethod
    def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        # Set datatype of condition tensor to boolean
        dtypeList[0] = DType.BOOL

        return TosaTensorValuesGen.tvgDefault(
            testGen, op, dtypeList, shapeList, testArgs, error_name
        )

    @staticmethod
    def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1, as the quotient
            #    +2**31 is not representable in int32
            while True:
                dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
                    continue

                break

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            tens = []
            if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32):
                tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure multiply result in int32 range
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                elif error_name == ErrorIf.WrongInputType:
                    num_bits = 8
                else:
                    raise Exception("OpMul: invalid input dtype")

                for idx, shape in enumerate(shapeList[:]):
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                    a_arr = np.int32(
                        testGen.rng.integers(low=low, high=high, size=shapeList[0])
                    )
                    b_arr = np.int32(
                        testGen.rng.integers(low=low, high=high, size=shapeList[1])
                    )

                i = 0
                while True:

                    a_arr_64 = a_arr.astype(np.int64)
                    b_arr_64 = b_arr.astype(np.int64)

                    if shift > 0:
                        rounding = 1 << (shift - 1)
                        result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                    else:
                        result_arr = a_arr_64 * b_arr_64

                    if (result_arr > -(2**31)).all() and (
                        result_arr <= ((2**31) - 1)
                    ).all():
                        break

                    i = i + 1
                    a_arr = a_arr // 2
                    b_arr = b_arr // 2

                placeholders.append(
                    testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)

            return tens
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
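
    # Illustrative halving for MUL (a sketch, INT32 inputs, shift=0):
    # a = 2**20 and b = 2**15 multiply to 2**35, which overflows int32, so
    # both operands are halved until the product (2**29 after three rounds)
    # fits within the int32 range.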

    @staticmethod
    def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        count = len(shapeList) - testGen.args.num_const_inputs_concat
        if count < 1:
            count = 1
        if testGen.args.num_const_inputs_concat == 0:
            count = len(shapeList)

        # Ensure axis is an int
        testArgs[0] = int(testArgs[0])

        shapeList = TosaTensorGen.tgConcatConstInput(
            testGen, shapeList, testArgs[0], error_name
        )

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
        )
        tens.extend(testGen.buildConstTensors(shapeList[count:], dtypeList[count:]))

        return tens

    @staticmethod
    def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        pCount, cCount = op["operands"]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
        values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
        shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
        placeholders = []
        placeholders.append(
            testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
        )
        placeholders.append(
            testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
        )

        return placeholders

    @staticmethod
    def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        testGen.rng.integers(
                            0, max(shapeList[0][axis], shapeList[1][axis])
                        )
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
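
    # Illustrative matching (a sketch): for rank-2 inputs, 4 random index
    # pairs are chosen and a_arr is set equal to b_arr at those positions,
    # so EQUAL tests exercise both true and false outputs.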

    @staticmethod
    def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.REDUCE_SUM must have 1 placeholder, 0 consts"
            # Limit values so that the sum cannot exceed the range of an int32 during
            # summation of any axis
            range_val = int((1 << 31) / max(shapeList[0]))
            values_arr = np.int32(
                testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
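
    # Illustrative bound (a sketch): for shape [3, 50], range_val is
    # 2**31 // 50, so even summing all 50 elements along an axis stays
    # within the int32 range.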


class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for
    operators that take attributes or other parameters.

    The return value is a list of (descriptive_name, [arglist]) tuples where
    the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function.
    """

    def __init__(self):
        pass

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype, error_name=None):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype, error_name=None):
        """Build the axis argument for operators that take a single axis"""
        axes = []
        shape = shapeList[0]

        if error_name == ErrorIf.AxisSmallerZero:
            small_axis = testGen.rng.integers(-5, 0)
            axes.append(("axis{}".format(small_axis), [small_axis]))
        elif error_name == ErrorIf.AxisLargerRank:
            large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
            axes.append(("axis{}".format(large_axis), [large_axis]))
        else:
            for a in range(0, len(shape)):
                axes.append(("axis{}".format(a), [a]))

        return axes
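
    # Illustrative output (a sketch): for a rank-3 shape and no error this
    # returns [("axis0", [0]), ("axis1", [1]), ("axis2", [2])], one test
    # per axis, with the name suffix appended to the test name.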

    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]
        # determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3])
        k = [int(x) for x in opName.split("_")[-1].split("x")]

        accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        # Check the rank
        rank = 5 if opName.startswith("conv3d") else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits batch and channels
        k_rank = rank - 2
        assert len(k) == k_rank

        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        if error_name == ErrorIf.PadSmallerZero:
            p_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            # Stride must be greater than 1 to force non-integer error
            startStride = 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
            s_vals = [x for x in range(startStride, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * k_rank))}
        if error_name == ErrorIf.DilationSmallerOne:
            d_vals = [testGen.rng.choice(range(-5, 1))]
        else:
            d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

        if not error_name and testGen.args.oversize:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    {x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))}
                )
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
            bigDilation = 7
            dilations.update(
                {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
            )

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 120
        sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a
                        # positive sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k[1] - 1)
                        and (
                            k_rank < 3
                            or ((ifm_shape[3] - 1 + p[4] + p[5]) > d[2] * (k[2] - 1))
                        )
                    ):
                        remainders = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            remainders.append(
                                (
                                    ifm_shape[index + 1]
                                    - 1
                                    + p[pad_offset]
                                    + p[pad_offset + 1]
                                    - (k[index] - 1) * d[index]
                                )
                                % s[index]
                            )
                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        "".join([str(x) for x in s]),
                                        "".join([str(x) for x in p]),
                                        "".join([str(x) for x in d]),
                                    ),
                                    [accum_dtype, s, p, d],
                                )
                            )
                    n += 1

        return arg_list
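
    # Illustrative sparsity (a sketch): with 16 paddings, 9 strides and 4
    # dilations there are 576 combinations; 576 // 120 + 1 = 5, which is
    # below 13, so sparsity becomes 1 and every valid combination is kept.
    # Larger spaces keep roughly one combination in every `sparsity`.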

    @staticmethod
    def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):

        assert isinstance(dtypes, list) or isinstance(
            dtypes, tuple
        ), f"{dtypes} unexpected"
        input_dtype = dtypes[0]

        if error_name == ErrorIf.WrongOutputType:
            accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype)
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            accum_dtype = DType.INT32
        else:
            accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])]

    @staticmethod
    def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
        # Get valid accumulate type(s)
        if dtype == DType.INT8:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.INT16:
            accum_dtypes = [DType.INT48]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FP32]
        elif dtype == DType.BF16:
            accum_dtypes = [DType.FP32]
        elif dtype == DType.FP32:
            accum_dtypes = [DType.FP32]
        elif error_name is None:
            assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"

        if error_name == ErrorIf.WrongOutputType:
            # Get incorrect output dtype for ErrorIf case
            accum_dtypes = [get_wrong_output_type(opName, testGen.rng, dtype)]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            accum_dtypes = [DType.INT32]

        return [(f"acc{testGen.typeStr(a)}", [a]) for a in accum_dtypes]

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        # Must be rank 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == 4
            assert len(filter_shape) == 4

        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        smallest_padding_size = -min(filter_shape[1], filter_shape[2]) + 1
        if error_name == ErrorIf.PadLargerEqualKernel:
            max_filter_size = -max(filter_shape[1], filter_shape[2])
            p_vals = [testGen.rng.choice(range(max_filter_size - 10, max_filter_size))]
        else:
            p_vals = [
                x
                for x in range(smallest_padding_size, testGen.args.max_conv_padding + 1)
            ]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}

        if not error_name and testGen.args.oversize:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    {
                        x
                        for x in itertools.product(
                            *([[smallest_padding_size, bigPadding]] * 4)
                        )
                    }
                )
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 10
        sparsity = len(paddings) * len(strides) // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                if n % sparsity == 0:
                    # Determine the output shape
                    oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + filter_shape[1]
                    ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + filter_shape[2]
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]
                    arg_list.append(
                        (
                            "acc{}_st{}_pad{}_os{}".format(
                                testGen.typeStr(accum_dtype),
                                "".join([str(x) for x in s]),
                                "".join([str(x) for x in p]),
                                "x".join([str(x) for x in os]),
                            ),
                            [accum_dtype, s, p, os],
                        )
                    )
                n += 1

        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            return []

        for paddings in shape_pad_values:
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing
                # for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                if all([p > -1 for p in paddings[i]]):
                    args_valid = False

            if args_valid:
                name = "pad"
                for r in range(rank):
                    before, after = paddings[r]
                    name = f"{name}{before}{after}"
                arg_list.append(
                    (name, [np.array(paddings), pad_const_int, pad_const_fp])
                )

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

        return arg_list
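
    # Illustrative naming (a sketch): for a rank-2 int8 input, paddings
    # ((0, 1), (1, 0)) produce the test name suffix "pad0110" and the
    # argument list [np.array([[0, 1], [1, 0]]), <random int const>, 0].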

    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        shape = shapeList[0]
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        # Generate comprehensive argument lists
        p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        # Stride must be greater than 1 to force non-integer error
        startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
        s_vals = [x for x in range(startStride, testGen.args.max_pooling_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}
        k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)]
        kernels = {x for x in itertools.product(*([k_vals] * 2))}

        if opName == "max_pool2d":
            accum_dtypes = [None]  # max_pool has no accumulate dtype
        elif dtype == DType.INT8 or dtype == DType.INT16:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FP32]
        elif dtype == DType.BF16 or dtype == DType.FP32:
            accum_dtypes = [DType.FP32]
        elif error_name is None:
            assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
        else:
            # Set to something for the ErrorIf case which has
            # incorrect input data-type
            accum_dtypes = [DType.INT32]

        if testGen.args.oversize:
            # add some oversize argument values
            bigStride = 7
            strides.update(
                {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
            )
            bigKernel = 9
            kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
            if max(shape) < 64:
                # padding must be less than the kernel size
                bigPadding = bigKernel - 1
                paddings.update(
                    {x for x in itertools.product(*([[0, bigPadding]] * 4))}
                )

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 500
        sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1

        arg_str = (
            "acc{}_st{}_kern{}_pad{}"
            if accum_dtypes[0] is not None
            else "st{}_kern{}_pad{}"
        )

        def get_arg_list_element(accum, stride, pad, kern):
            # Return tuple containing the formatted argument string and
            # the corresponding argument values
            arg_str_elems = [
                "".join([str(x) for x in stride]),
                "".join([str(x) for x in kern]),
                "".join([str(x) for x in pad]),
            ]
            # Note: different order to string
            arg_val_elems = [stride, pad, kern]

            if accum is not None:
                arg_str_elems.insert(0, testGen.typeStr(accum))
                arg_val_elems.insert(0, accum)
            return (arg_str.format(*arg_str_elems), arg_val_elems)

        n = 0
        for a in accum_dtypes:
            for s in sorted(list(strides)):
                for p in sorted(list(paddings)):
                    for k in sorted(list(kernels)):
                        if error_name in [
                            ErrorIf.StrideSmallerOne,
                            ErrorIf.KernelSmallerOne,
                            ErrorIf.PadSmallerZero,
                            ErrorIf.PadLargerEqualKernel,
                        ]:
                            sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
                                testGen, error_name, s, p, k
                            )
                            if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                                arg_vals = [a, sNew, pNew, kNew]
                                arg_list.append(get_arg_list_element(*arg_vals))
                        elif (
                            n % sparsity == 0
                            # padding must not exceed the kernel size
                            and p[0] < k[0]
                            and p[1] < k[0]
                            and p[2] < k[1]
                            and p[3] < k[1]
                            # the padded shape must exceed the kernel size
                            and (shape[1] + p[0] + p[1]) > k[0]
                            and (shape[2] + p[2] + p[3]) > k[1]
                        ):
                            remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0]
                            remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1]
                            if (
                                # the parameters must produce integer exact output
                                error_name != ErrorIf.PoolingOutputShapeNonInteger
                                and remainder_h == 0
                                and remainder_w == 0
                            ) or (
                                error_name == ErrorIf.PoolingOutputShapeNonInteger
                                and (remainder_h != 0 or remainder_w != 0)
                            ):
                                arg_vals = [a, s, p, k]
                                arg_list.append(get_arg_list_element(*arg_vals))
                        n += 1

        return arg_list
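
    # Illustrative name (a sketch, assuming testGen.typeStr(DType.FP32)
    # returns "f32"): an avg_pool2d fp32 test with stride (1, 1), kernel
    # (2, 2) and padding (0, 0, 0, 0) is suffixed "accf32_st11_kern22_pad0000"
    # (accumulator first, then stride/kernel/pad; the values list orders
    # stride, pad, kernel instead).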
1478
    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        if error_name == ErrorIf.WrongOutputType:
            dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
        elif inDtype == DType.INT8:
            dtypeList = [
                DType.BOOL,
                DType.INT16,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT16:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT32:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT16,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FP16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
        elif inDtype == DType.BF16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
        elif inDtype == DType.FP32:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output type for incorrect input type
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(testGen.typeStr(dtype)), [dtype]))

        return arg_list

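    # For illustration: with inDtype=DType.BOOL and no error_name, agCast
    # above yields one entry per legal output type, named via testGen.typeStr
    # (assuming typeStr returns names such as "i8"/"i16"/"i32"), e.g.
    #     [("outi8", [DType.INT8]), ("outi16", [DType.INT16]),
    #      ("outi32", [DType.INT32])]
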
    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # These ErrorIfs are only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 input are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 output are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 input is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 output is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:
                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition.  Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition.  ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [outDtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

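    # A minimal sketch (hypothetical helper, not used by the generator) of
    # the unsigned-type pairing rules enforced above for positive RESCALE
    # tests: UINT8 converts only to/from INT8 or INT16, and UINT16 only
    # to/from INT16.
    @staticmethod
    def _exampleRescaleUnsignedPairingOk(inDtype, outDtype):
        if inDtype == DType.UINT8:
            return outDtype in [DType.INT8, DType.INT16]
        if outDtype == DType.UINT8:
            return inDtype in [DType.INT8, DType.INT16]
        if inDtype == DType.UINT16:
            return outDtype == DType.INT16
        if outDtype == DType.UINT16:
            return inDtype == DType.INT16
        # Signed-to-signed pairings are constrained elsewhere
        return True
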
    @staticmethod
    def agMul(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        if dtype == DType.INT32:
            for p in range(testGen.args.num_rand_permutations):
                shift = testGen.randInt(0, 32)
                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("perm0_shift0", [0]))

        return arg_list

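    # For illustration: MUL only takes a result shift for INT32 operands,
    # where the generator picks a random shift in [0, 32) for each
    # permutation; every other dtype gets the single fixed entry
    # ("perm0_shift0", [0]) above.
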
    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    @staticmethod
    def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        arg_list.append(("inverseTrue", [True]))
        arg_list.append(("inverseFalse", [False]))

        return arg_list

    # Helper function for reshape.  Gets factors of a larger number, up to
    # its square root.
    @staticmethod
    def getFactors(val, start=1):
        factors = []

        for i in range(start, int(np.sqrt(val)) + 1):
            if (val % i) == 0:
                factors.append(i)

        return factors

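    # For illustration, getFactors only returns factors up to sqrt(val):
    #     getFactors(16) == [1, 2, 4]
    #     getFactors(24) == [1, 2, 3, 4]
    # agReshape below compensates by re-factoring the remaining element count
    # after each dimension is chosen.
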
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast.  Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter limits the generation loop in case it runs
            # for too long without finding a new shape
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick newRank - 1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list

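    # An illustrative sketch (hypothetical helper, not used by the generator)
    # of how agReshape assembles one candidate shape: pick newRank - 1
    # dimensions from the shuffled factors, then use whatever element count
    # remains as the final dimension, so the product always equals
    # totalElements.
    @staticmethod
    def _exampleBuildNewShape(rng, totalElements, newRank):
        newShape = []
        remaining = totalElements
        factors = TosaArgGen.getFactors(remaining)
        for _ in range(newRank - 1):
            # getFactors always contains 1, so factors is never empty
            dim = int(rng.permutation(factors)[0])
            newShape.append(dim)
            remaining = remaining // dim
            factors = TosaArgGen.getFactors(remaining)
        newShape.append(remaining)
        return newShape
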
    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]

        if error_name == ErrorIf.IndexOutsideBounds:
            incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
            incorrect_small_index = range(-len(ifm_shape), 0)
            permutations = [p for p in itertools.permutations(incorrect_large_index)]
            permutations.extend(
                [p for p in itertools.permutations(incorrect_small_index)]
            )
        elif error_name == ErrorIf.IndexUsedTwice:
            # Create list with a duplicated index
            perm_range = list(range(len(ifm_shape)))
            index_choice = testGen.rng.choice(range(len(perm_range)))
            perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
            permutations = [p for p in itertools.permutations(perm_range)]
        else:
            # Get all permutations
            permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Get random permutation generator that uses all permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create a list of the required number of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list

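    # For illustration: a rank 3 input has 3! == 6 possible axis orders, so
    # with num_rand_permutations >= 6 agTranspose above emits all six in
    # shuffled order, e.g. ("perm0", [[2, 0, 1]]).
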
    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            start = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    start.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - start[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    start.append(0)
                    size.append(1)

            if valid:
                # If an ERROR_IF test is required then incorrect start/size
                # values will be returned
                start, size = TosaErrorIfArgGen.eiSliceErrorIf(
                    testGen, error_name, ifm_shape, start, size
                )
                arg_list.append(("perm{}".format(p), [start, size]))
        return arg_list

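    # A minimal sketch (hypothetical helper, not used by the generator) of
    # the per-dimension constraint agSlice relies on: a valid slice needs
    # 0 <= start < dim and 1 <= size <= dim - start for every dimension.
    @staticmethod
    def _exampleSliceIsValid(ifm_shape, start, size):
        return all(
            0 <= st < dim and 1 <= sz <= dim - st
            for dim, st, sz in zip(ifm_shape, start, size)
        )
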
    @staticmethod
    def agTile(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            # Pick a few random, but small, multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                if ifm_shape[i] > 1000:
                    # Use a multiple of 1 if the ifm_shape dimension is large,
                    # to reduce the tensor size
                    multiples.append(1)
                elif max(ifm_shape) > 1000:
                    multiples.append(2)
                else:
                    multiples.append(testGen.randInt(1, 4))
            arg_list.append(("perm{}".format(p), [multiples]))

        return arg_list

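    # For illustration: TILE multiplies each dimension by its multiple, so an
    # input of shape [2, 3] with multiples [3, 1] gives an output of shape
    # [6, 3]; capping the multiples above keeps generated tensors small.
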
    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []
        ifm_shape = shapeList[0]

        def get_aspect_ratio_resize_params():
            common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
            aspect_ratio = testGen.rng.choice(common_aspect_ratios)
            invert = testGen.rng.choice((False, True))
            letterbox = testGen.rng.choice((False, True))

            scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
            scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
            scale_y_d = scale_x_d = 1
            offset_x = offset_y = 0

            if letterbox:
                max_border = scale_y_n
                border_y = testGen.randInt(low=0, high=max_border)
                border_x = 0
            else:
                # Pillarboxing
                border_y = 0
                max_border = scale_x_n
                border_x = testGen.randInt(low=0, high=max_border)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)

            return scale, offset, border

        def get_upscale_downscale_params():
            valid_params = False
            while not valid_params:
                upscale = testGen.rng.choice((False, True))

                # True if sampling begins from (0, 0).  Otherwise (-0.5, -0.5)
                origin_sampling = testGen.rng.choice((False, True))

                if upscale:
                    shift = testGen.randInt(low=1, high=4)
                    scale_x_d = scale_y_d = 1
                    scale_x_n = scale_y_n = (
                        1 << shift if origin_sampling else 2 << shift
                    )
                    border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                    offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
                else:
                    scale_x_n = 1
                    scale_y_n = 1

                    # Return list of valid scale_*_d values (max value 4) given
                    # the input dim shape; x > 1 divides (ifm_dim - 1) exactly
                    # when ifm_dim % x == 1 (a denominator of 1 is covered by
                    # the fallbacks below)
                    def get_valid_denom(ifm_dim):
                        return [x for x in range(1, 5) if ifm_dim % x == 1]

                    # Generate list of valid downscale values and choose one randomly
                    valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                    valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                    if not valid_scale_y_ds and not valid_scale_x_ds:
                        # Bad parameters, skip
                        continue

                    if not valid_scale_y_ds:
                        scale_y_d = 1
                    else:
                        scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                    if not valid_scale_x_ds:
                        scale_x_d = 1
                    else:
                        scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                    border_x = border_y = 0
                    offset_y = testGen.randInt(0, 16 * scale_y_n)
                    offset_x = testGen.randInt(0, 16 * scale_x_n)
                valid_params = True

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_rand_params():
            # Scale
            scale_y_n = testGen.randInt(low=1, high=(1 << 11))
            scale_x_n = testGen.randInt(low=1, high=(1 << 11))

            scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

            # Offsets and border within the scale
            offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
            border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations.  Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.BF16:
                outputDTypeList = [DType.BF16]
            elif dtype == DType.FP32:
                outputDTypeList = [DType.FP32]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                perm = 0
                while perm < testGen.args.num_rand_permutations:
                    # Random choice of type of params we are testing
                    _rnd_param_fn = testGen.rng.choice(
                        (
                            get_rand_params,
                            get_upscale_downscale_params,
                            get_aspect_ratio_resize_params,
                        )
                    )
                    scale, offset, border = _rnd_param_fn()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            perm += 1
                            continue
                    else:
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1

                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1

                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]

                    # Common for all data types
                    if error_name is not None:
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        [
                            mode,
                            scale,
                            offset,
                            border,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)
        return arg_list

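    # A minimal sketch (hypothetical helper, not used by the generator) of
    # the output size computation used above, from the TOSA RESIZE
    # definition:
    #     OH = ((IH - 1) * scale_y_n - offset_y + border_y) // scale_y_d + 1
    #     OW = ((IW - 1) * scale_x_n - offset_x + border_x) // scale_x_d + 1
    # which is only integer exact when each numerator is divisible by its
    # scale denominator.  Assumes an NHWC shape.
    @staticmethod
    def _exampleResizeOutputSize(ifm_shape, scale, offset, border):
        (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
        partial_y = (ifm_shape[1] - 1) * scale_y_n - offset[0] + border[0]
        partial_x = (ifm_shape[2] - 1) * scale_x_n - offset[1] + border[1]
        assert partial_y % scale_y_d == 0 and partial_x % scale_x_d == 0
        return (partial_y // scale_y_d + 1, partial_x // scale_x_d + 1)
        # e.g. ifm_shape=[1, 5, 5, 1], scale=(2, 1, 2, 1), offset=(0, 0),
        # border=(0, 0) gives an output of (9, 9): a 2x upscale
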
    @staticmethod
    def agTable(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        if dtype == DType.INT8:
            table = np.int32(
                testGen.rng.integers(low=-128, high=128, size=[256])
            ).tolist()
        else:  # INT16
            table = np.int32(
                testGen.rng.integers(low=-32768, high=32768, size=[513])
            ).tolist()
            # Make sure all slopes are within the REQUIRE min/max 16-bit int range
            for idx in range(len(table) - 1):
                slope = table[idx + 1] - table[idx]
                # Alter the next table entry to force the slope to be ok
                if slope > 32767:
                    table[idx + 1] -= slope - 32767
                if slope < -32768:
                    table[idx + 1] -= slope + 32768
                slope = table[idx + 1] - table[idx]
                assert slope <= 32767 and slope >= -32768
        arg_list.append(
            (
                "",
                [table],
            )
        )
        return arg_list

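    # A minimal sketch (hypothetical helper, not used by the generator) of
    # the slope invariant the INT16 clamping loop above establishes: every
    # difference between adjacent TABLE entries must fit in int16.
    @staticmethod
    def _exampleTableSlopesOk(table):
        return all(
            -32768 <= table[idx + 1] - table[idx] <= 32767
            for idx in range(len(table) - 1)
        )
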
    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
        # While loop: 0 iterations, 1, more than 1
        arg_list = []

        for iterations in [0, 1, 4]:
            arg_list.append(("iter{}".format(iterations), [iterations]))

        return arg_list