# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
import math
import warnings

import numpy as np
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_utils import get_accum_dtype_from_tgTypes
from generator.tosa_utils import get_wrong_output_type
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode

# DTypeNames, DType, Op and ResizeMode are convenience variables to the
# flatc-generated types that should be enums, but aren't


class TosaQuantGen:
23 """QuantizedInfo random generator helper functions.
24
25 Specify with 'qgen': in the operator defintion.
26 """
27
28 def __init__(self):
29 pass
30
    @staticmethod
    def getZeroPoint(testGen, dtype, error_name=None):

        if dtype == DType.INT8:
            if testGen.args.zeropoint is not None:
                return min(127, max(-128, testGen.args.zeropoint))
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            if testGen.args.zeropoint is not None:
                return min(255, max(0, testGen.args.zeropoint))
            return testGen.randInt(0, 256)
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            zero_point = testGen.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
        #     scaleFp, scaleBits, m, multiplier, shift))

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift
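
    # Worked example (scale32=True, so scaleBits=31): scaleFp = 0.8 gives
    # math.frexp(0.8) == (0.8, 0), so multiplier = round(0.8 * (1 << 31))
    # == 1717986918 and shift = -0 + 31 == 31; a rescale then computes
    # (value * 1717986918) >> 31, which approximates value * 0.8.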


class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = min(values_in_shape[0], testGen.args.max_batch_size)

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        fuzz_idx = testGen.randInt(0, rank)

        for i in range(pl + const):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list
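
    # For example, with shape [3, 4, 5] and fuzz_idx == 1, the chosen broadcast
    # input becomes [3, 1, 5], which broadcasts against the other [3, 4, 5]
    # inputs; for ErrorIf.BroadcastShapesMismatch it becomes [3, 6, 5], which
    # is neither 1 nor a matching size and so must be rejected.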

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension()

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]
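
    # For example, with ifm_shape [1, 8, 8, 5] (so C == 5) and filter_m == 3,
    # the filter shape is [kh, kw, 5, 3] (HWCM) and the bias/output depth is
    # M * C == 15.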

    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]
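
    # The nearest lower power of two is picked with 2 ** int(math.log(dim, 2)),
    # e.g. a randomly generated height of 10 becomes 2 ** int(3.32...) == 8.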

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
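
    # For example, four equal shapes of [2, 8, 3] split on axis 1 become axis
    # lengths 8, 4, 2 and 2: the original shape is kept as the first input,
    # each iteration halves the previous length, and the final remainder is
    # appended as the last input.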


class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        pCount, cCount = op["operands"]

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        return tens

    @staticmethod
    def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32 and error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.NEGATE must have 1 placeholder, 0 consts"
            # Must create tensors with values within accumulator (int32) negatable
            # range
            max_val = (1 << 31) - 1
            min_val = -max_val
            arr = np.int32(
                testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Make sure the operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            placeholders = []
            add = op["op"] == Op.ADD
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform
                # opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
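
    # Worked example of the clipping above for Op.ADD: a == 2**31 - 10 and
    # b == 100 gives res == 2**31 + 90, which exceeds max_i32 by 91, so
    # sat_max_arr == 91 and b is clipped to 100 - 91 == 9; a + 9 then lands
    # exactly on the int32 maximum without wrapping.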

    @staticmethod
    def tvgCondIfWhileLoop(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        if dtypeList[0] in (
            DType.INT32,
            DType.INT16,
            DType.INT8,
        ):
            # Limit input tensors with cond_if_binary or while_loop to stop
            # saturation of add/sub ops with int32 and keep all logical shift
            # values between 0 to 31 for int16 or int8
            pCount, cCount = op["operands"]
            pRemain = pCount
            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[0] == DType.INT32:
                    arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
                else:
                    arr = np.int32(
                        testGen.rng.integers(low=0, high=32, size=shapeList[idx])
                    )
                if pRemain > 0:
                    placeholders.append(
                        testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
                    )
                    pRemain -= 1
                else:
                    placeholders.append(
                        testGen.ser.addConst(shape, dtypeList[idx], arr)
                    )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgArithmeticRightShift(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        pCount, cCount = op["operands"]
        # Force value of operand[1] to be within [0, num_bits]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

        placeholders = []
        for idx, shape in enumerate(shapeList[:]):
            if idx == 1:
                if dtypeList[idx] == DType.INT8:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                elif dtypeList[idx] == DType.INT16:
                    arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
                elif dtypeList[idx] == DType.INT32:
                    arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
                elif error_name == ErrorIf.WrongInputType:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                else:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
            else:
                arr = testGen.getRandTensor(shape, dtypeList[idx])
            placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))

        return placeholders

    @staticmethod
    def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        # Set datatype of condition tensor to boolean
        dtypeList[0] = DType.BOOL

        return TosaTensorValuesGen.tvgDefault(
            testGen, op, dtypeList, shapeList, testArgs, error_name
        )

    @staticmethod
    def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
            while True:
                dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
                    continue

                break

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            tens = []
            if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32):
                tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure multiply result in int32 range
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                elif error_name == ErrorIf.WrongInputType:
                    num_bits = 8
                else:
                    raise Exception("OpMul: invalid input dtype")

                for idx, shape in enumerate(shapeList[:]):
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                    a_arr = np.int32(
                        testGen.rng.integers(low=low, high=high, size=shapeList[0])
                    )
                    b_arr = np.int32(
                        testGen.rng.integers(low=low, high=high, size=shapeList[1])
                    )

                i = 0
                while True:

                    a_arr_64 = a_arr.astype(np.int64)
                    b_arr_64 = b_arr.astype(np.int64)

                    if shift > 0:
                        rounding = 1 << (shift - 1)
                        result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                    else:
                        result_arr = a_arr_64 * b_arr_64

                    if (result_arr > -(2**31)).all() and (
                        result_arr <= ((2**31) - 1)
                    ).all():
                        break

                    i = i + 1
                    a_arr = a_arr // 2
                    b_arr = b_arr // 2

                placeholders.append(
                    testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)

            return tens
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
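
    # Worked example of the rounded shift above: with shift == 2 the rounding
    # term is 1 << 1 == 2, so a product of 10 gives (10 + 2) >> 2 == 3, i.e.
    # 10 / 4 rounded to nearest; if any result overflows int32 both inputs are
    # halved and the check is retried.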

    @staticmethod
    def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        count = len(shapeList) - testGen.args.num_const_inputs_concat
        if count < 1:
            count = 1
        if testGen.args.num_const_inputs_concat == 0:
            count = len(shapeList)

        # Ensure axis is an int
        testArgs[0] = int(testArgs[0])

        shapeList = TosaTensorGen.tgConcatConstInput(
            testGen, shapeList, testArgs[0], error_name
        )

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
        )
        tens.extend(testGen.buildConstTensors(shapeList[count:], dtypeList[count:]))

        return tens

    @staticmethod
    def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        pCount, cCount = op["operands"]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
        values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
        shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
        placeholders = []
        placeholders.append(
            testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
        )
        placeholders.append(
            testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
        )

        return placeholders

    @staticmethod
    def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        testGen.rng.integers(
                            0, max(shapeList[0][axis], shapeList[1][axis])
                        )
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.REDUCE_SUM must have 1 placeholder, 0 consts"
            # Limit values so that the sum cannot exceed the range of an int32 during
            # summation of any axis
            range_val = int((1 << 31) / max(shapeList[0]))
            values_arr = np.int32(
                testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
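
    # For example, a [5, 10] int32 input gives range_val == int(2**31 / 10),
    # so the sum along any axis has magnitude at most 10 * range_val ~= 2**31
    # and stays within the int32 accumulator.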


class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for
    operators that take attributes or other parameters.

    The return value is a list of (descriptive_name, [arglist]) tuples where
    the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function.
    """

    def __init__(self):
        pass

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype, error_name=None):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype, error_name=None):
        """Build the axis argument for operators that take a single axis"""
        axes = []
        shape = shapeList[0]

        if error_name == ErrorIf.AxisSmallerZero:
            small_axis = testGen.rng.integers(-5, 0)
            axes.append(("axis{}".format(small_axis), [small_axis]))
        elif error_name == ErrorIf.AxisLargerRank:
            large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
            axes.append(("axis{}".format(large_axis), [large_axis]))
        else:
            for a in range(0, len(shape)):
                axes.append(("axis{}".format(a), [a]))

        return axes

    @staticmethod
    def _calculate_sparsity(num_tests, sparsity_factor):
        sparsity = num_tests // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
        return sparsity
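
    # For example, 10000 candidate tests with sparsity_factor 120 give an
    # initial sparsity of 84, which is then stepped past multiples of 2, 3
    # and 5 up to 89, so roughly every 89th parameter combination is kept.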

    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        # Shape: Batches, (Depth), Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: (OFM channels), (KD), KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        # Check the rank
        conv3d = opName.startswith("conv3d")
        rank = 5 if conv3d else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits channels
        k_rank = rank - 2
        k_pos = 0 if opName.startswith("depthwise") else 1
        k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            if error_name == ErrorIf.PadSmallerZero:
                p_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                # Stride must be greater than 1 to force non-integer error
                startStride = (
                    1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
                )
                s_vals = [
                    x for x in range(startStride, testGen.args.max_conv_stride + 1)
                ]
            strides = {x for x in itertools.product(*([s_vals] * k_rank))}
            if error_name == ErrorIf.DilationSmallerOne:
                d_vals = [testGen.rng.choice(range(-5, 1))]
            else:
                d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
            dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[0, bigPadding]] * (k_rank * 2))
                            )
                        }
                    )
                bigStride = 8
                strides.update(
                    {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
                )
                bigDilation = 7
                dilations.update(
                    {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
                )
            max_dim_size = None

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(paddings) * len(strides) * len(dilations), sparsity_factor
            )
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            dilation_shape = [1] * k_rank
            pad_shape = [0] * k_rank * 2
            if conv3d:
                # Small stride apart from for big kernel (see below) to keep
                # tensor size/calculation small
                stride_shape = [1] * k_rank
                for idx in range(k_rank):
                    pad_offset = idx * 2
                    if k_shape[idx] == bigKernel:
                        # Padding shape needs to account for tensor shape
                        pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                        pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                        # Big stride to reduce output size
                        stride_shape[idx] = bigKernel
                    else:
                        # Account for kernel size
                        pad_shape[pad_offset] = k_shape[idx] - 1
            else:
                # Always have a large stride with extra padding and dilation to keep
                # tensor calculation reasonable
                stride_shape = [bigKernel] * k_rank
                for idx in range(k_rank):
                    # Dilation shape must account for kernel size
                    dilation_shape[idx] = bigKernel // k_shape[idx]
                    # Padding shape needs to accommodate tensor/kernel & dilation
                    pad_offset = idx * 2
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

            strides = {tuple(stride_shape)}
            dilations = {tuple(dilation_shape)}
            paddings = {tuple(pad_shape)}
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a
                        # positive sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
                        and (
                            k_rank < 3
                            or (
                                (ifm_shape[3] - 1 + p[4] + p[5])
                                > d[2] * (k_shape[2] - 1)
                            )
                        )
                    ):
                        remainders = []
                        outputs = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            partial = (
                                ifm_shape[index + 1]
                                - 1
                                + p[pad_offset]
                                + p[pad_offset + 1]
                                - (k_shape[index] - 1) * d[index]
                            )
                            remainders.append(partial % s[index])
                            outputs.append((partial // s[index]) + 1)

                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            if (
                                max_dim_size is not None
                                and max(outputs) >= max_dim_size
                            ):
                                # Test will consume too much memory - skip it
                                continue

                            # Support for larger values than 9 needs different delimiter
                            delim = "" if max(s + p + d) <= 9 else "x"
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        delim.join([str(x) for x in s]),
                                        delim.join([str(x) for x in p]),
                                        delim.join([str(x) for x in d]),
                                    ),
                                    [accum_dtype, s, p, d],
                                )
                            )
                    n += 1

        return arg_list
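
    # The per-axis output size computed above follows the standard convolution
    # relationship:
    #   partial = ifm - 1 + pad_before + pad_after - (kernel - 1) * dilation
    #   output  = partial // stride + 1  (integer exact when partial % stride == 0)
    # e.g. ifm 16, pads (1, 1), kernel 3, dilation 2, stride 2 gives
    # partial == 13 and remainder 1, a combination only selected for the
    # ConvOutputShapeNonInteger negative test.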

    @staticmethod
    def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):

        assert isinstance(dtypes, list) or isinstance(
            dtypes, tuple
        ), f"{dtypes} unexpected"
        input_dtype = dtypes[0]

        if error_name == ErrorIf.WrongOutputType:
            accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype)
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            accum_dtype = DType.INT32
        else:
            accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])]

    @staticmethod
    def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
        # Get valid accumulate type(s)
        if dtype == DType.INT8:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.INT16:
            accum_dtypes = [DType.INT48]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FP32]
        elif dtype == DType.BF16:
            accum_dtypes = [DType.FP32]
        elif dtype == DType.FP32:
            accum_dtypes = [DType.FP32]
        elif error_name is None:
            assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"

        if error_name == ErrorIf.WrongOutputType:
            # Get incorrect output dtype for ErrorIf case
            accum_dtypes = [get_wrong_output_type(opName, testGen.rng, dtype)]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            accum_dtypes = [DType.INT32]

        return [(f"acc{testGen.typeStr(a)}", [a]) for a in accum_dtypes]

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        # Must be rank 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == 4
            assert len(filter_shape) == 4

        k_shape = tuple(filter_shape[1:3])

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
            if error_name == ErrorIf.PadLargerEqualKernel:
                max_filter_size = -max(k_shape[0], k_shape[1])
                p_vals = [
                    testGen.rng.choice(range(max_filter_size - 10, max_filter_size))
                ]
            else:
                p_vals = [
                    x
                    for x in range(
                        smallest_padding_size, testGen.args.max_conv_padding + 1
                    )
                ]
            paddings = {x for x in itertools.product(*([p_vals] * 4))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
            strides = {x for x in itertools.product(*([s_vals] * 2))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[smallest_padding_size, bigPadding]] * 4)
                            )
                        }
                    )
                bigStride = 8
                strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 10
            sparsity = len(paddings) * len(strides) // sparsity_factor + 1
            # If there are only a small number of tests, just select them all
            if sparsity < 13:
                sparsity = 1
            # To get a variety of parameter combinations sparsity should not be a
            # multiple of 2, 3 or 5
            while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
                sparsity += 1
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            pad_shape = [0] * (len(k_shape) * 2)
            stride_shape = [1] * len(k_shape)
            # The point at which input dimension combined with the stride will
            # create large output sizes!
            LARGE_SIZE = 2
            for idx in range(len(k_shape)):
                pad_offset = idx * 2
                if k_shape[idx] == bigKernel:
                    # Set large stride
                    stride_shape[idx] = bigKernel
                    # Use negative output padding to reduce shape size
                    pad_shape[pad_offset] = -(bigPadding - 1)
                    if ifm_shape[idx + 1] > LARGE_SIZE:
                        pad_shape[pad_offset + 1] = -(bigPadding - 1)
                else:
                    # The other dimension should be the bigKernel
                    alt_idx = 1 - idx
                    if (
                        k_shape[alt_idx] == bigKernel
                        and ifm_shape[alt_idx + 1] < LARGE_SIZE
                    ):
                        # As the input is small, the large stride won't
                        # affect the output so we can add some padding
                        pad_shape[pad_offset + 1] = bigPadding

            strides = {tuple(stride_shape)}
            paddings = {tuple(pad_shape)}

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                if n % sparsity == 0:
                    # Determine the output shape
                    oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
                    ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]

                    # Support for larger values than 9 needs different delimiter
                    delim = "" if max(s + p) <= 9 else "x"
                    arg_list.append(
                        (
                            "acc{}_st{}_pad{}_os{}".format(
                                testGen.typeStr(accum_dtype),
                                delim.join([str(x) for x in s]),
                                delim.join([str(x) for x in p]),
                                "x".join([str(x) for x in os]),
                            ),
                            [accum_dtype, s, p, os],
                        )
                    )
                n += 1

        return arg_list
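
    # Transpose convolution output size grows with stride rather than
    # shrinking: oh = (ifm_h - 1) * stride_y + pad_top + pad_bottom + kh,
    # e.g. ifm_h 8, stride 2, pads (0, 0) and kh 3 give oh == 7 * 2 + 3 == 17.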

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            return []

        list_shape_pad_values = list(shape_pad_values)
        # If we are producing tests for rank 6 or greater use sparsity
        if len(list_shape_pad_values) > 1024:
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(list_shape_pad_values), sparsity_factor
            )
        else:
            sparsity = 1

        for n, paddings in enumerate(list_shape_pad_values):
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while still testing for
                # negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                if all([p > -1 for p in paddings[i]]):
                    args_valid = False
            if args_valid and n % sparsity == 0:
                name = "pad"
                for r in range(rank):
                    before, after = paddings[r]
                    name = f"{name}{before}{after}"
                arg_list.append(
                    (name, [np.array(paddings), pad_const_int, pad_const_fp])
                )

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

        return arg_list

    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        shape = shapeList[0]
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        test_level8k = testGen.args.level8k and error_name is None

        startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
        startKernel = 2
        startPad = 0
        if not test_level8k:
            # Generate comprehensive argument lists
            p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * 4))}
            # Stride must be greater than 1 to force the non-integer error
            s_vals = [
                x for x in range(startStride, testGen.args.max_pooling_stride + 1)
            ]
            strides = {x for x in itertools.product(*([s_vals] * 2))}
            k_vals = [
                x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
            ]
            kernels = {x for x in itertools.product(*([k_vals] * 2))}
            max_dim_size = None
        else:
            # Only test 8k levels
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            strides = {(1, bigStride), (bigStride, 4)}
            kernels = {(1, bigKernel), (bigKernel, 3)}
            paddings = set()
            for s in sorted(list(strides)):
                for k in sorted(list(kernels)):
                    padding = []
                    for idx in range(len(k)):
                        total_padding = s[idx] - shape[idx + 1] + k[idx]
                        while total_padding < 0:
                            # Must meet: shape + padding > kernel
                            total_padding += s[idx]
                        if total_padding < k[idx]:
                            padding.extend([0, total_padding])
                        else:
                            # Note this may produce padding >= k[idx], which is not
                            # allowed - but will be ignored in the creation loop below
                            padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
                    paddings.add(tuple(padding))
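                    # total_padding was chosen so that
                    # (shape + padding - kernel) is an exact multiple of the
                    # stride, keeping the pooling output size integer-exact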
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

        if opName == "max_pool2d":
            accum_dtypes = [None]  # max_pool has no accumulate dtype
        elif dtype == DType.INT8 or dtype == DType.INT16:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FP32]
        elif dtype == DType.BF16 or dtype == DType.FP32:
            accum_dtypes = [DType.FP32]
        elif error_name is None:
            assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
        else:
            # Set to something for the ErrorIf case which has
            # an incorrect input data-type
            accum_dtypes = [DType.INT32]

        if not test_level8k:
            if testGen.args.oversize:
                # add some oversize argument values
                bigStride = 7
                bigKernel = 9
                strides.update(
                    {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
                )
                kernels.update(
                    {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
                )
                if max(shape) < 64:
                    # padding must be less than the kernel size
                    bigPadding = bigKernel - 1
                    paddings.update(
                        {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
                    )

            # There are too many parameter combinations, so generate them sparsely,
            # very sparsely for negative tests
            sparsity_factor = 2 if error_name else 500
            sparsity = (
                len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
            )
        else:
            # We have already limited test output combinations for 8k tests
            sparsity = 1

        arg_str = (
            "acc{}_st{}_kern{}_pad{}"
            if accum_dtypes[0] is not None
            else "st{}_kern{}_pad{}"
        )

        def get_arg_list_element(accum, stride, pad, kern):
            # Return a tuple containing the formatted argument string and
            # the corresponding argument values

            # Values larger than 9 need a different delimiter
            delim = "" if max(stride + kern + pad) <= 9 else "x"
            arg_str_elems = [
                delim.join([str(x) for x in stride]),
                delim.join([str(x) for x in kern]),
                delim.join([str(x) for x in pad]),
            ]
            # Note: values are in a different order to the string
            arg_val_elems = [stride, pad, kern]

            if accum is not None:
                arg_str_elems.insert(0, testGen.typeStr(accum))
                arg_val_elems.insert(0, accum)
            return (arg_str.format(*arg_str_elems), arg_val_elems)

        n = 0
        for a in accum_dtypes:
            for s in sorted(list(strides)):
                for p in sorted(list(paddings)):
                    for k in sorted(list(kernels)):
                        if error_name in [
                            ErrorIf.StrideSmallerOne,
                            ErrorIf.KernelSmallerOne,
                            ErrorIf.PadSmallerZero,
                            ErrorIf.PadLargerEqualKernel,
                        ]:
                            sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
                                testGen, error_name, s, p, k
                            )
                            if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                                arg_vals = [a, sNew, pNew, kNew]
                                arg_list.append(get_arg_list_element(*arg_vals))
                        elif (
                            n % sparsity == 0
                            # padding must not exceed the kernel size
                            and p[0] < k[0]
                            and p[1] < k[0]
                            and p[2] < k[1]
                            and p[3] < k[1]
                            # the padded shape must exceed the kernel size
                            and (shape[1] + p[0] + p[1]) > k[0]
                            and (shape[2] + p[2] + p[3]) > k[1]
                        ):
                            partial_h = shape[1] + p[0] + p[1] - k[0]
                            partial_w = shape[2] + p[2] + p[3] - k[1]
                            remainder_h = partial_h % s[0]
                            remainder_w = partial_w % s[1]
                            output_h = partial_h // s[0] + 1
                            output_w = partial_w // s[1] + 1
                            # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
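                            # e.g. shape[1] = 10, p = (0, 0, 0, 0), k = (3, 3),
                            # s = (2, 2): partial_h = 10 - 3 = 7, remainder_h = 1,
                            # so these params only suit the non-integer error test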
                            if (
                                # the parameters must produce integer exact output
                                error_name != ErrorIf.PoolingOutputShapeNonInteger
                                and remainder_h == 0
                                and remainder_w == 0
                            ) or (
                                error_name == ErrorIf.PoolingOutputShapeNonInteger
                                and (remainder_h != 0 or remainder_w != 0)
                            ):
                                if (
                                    max_dim_size is not None
                                    and max(output_h, output_w) > max_dim_size
                                ):
                                    # Test will consume too much memory - skip it
                                    continue
                                arg_vals = [a, s, p, k]
                                arg_list.append(get_arg_list_element(*arg_vals))
                        n += 1

        return arg_list

    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        if error_name == ErrorIf.WrongOutputType:
            dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
        elif inDtype == DType.INT8:
            dtypeList = [
                DType.BOOL,
                DType.INT16,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT16:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT32:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT16,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FP16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
        elif inDtype == DType.BF16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
        elif inDtype == DType.FP32:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output type for the incorrect input type
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(testGen.typeStr(dtype)), [dtype]))

        return arg_list

    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # These ErrorIfs are only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition. Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [outDtype, scale32, double_round, per_channel],
                            )
                        )
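                        # the generated names take the form
                        # out<dtype>_sc<0|1>_dr<0|1>_pc<0|1>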

        return arg_list

    @staticmethod
    def agMul(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        if dtype is DType.INT32:
            for p in range(testGen.args.num_rand_permutations):

                shift = testGen.randInt(0, 32)
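                # assuming testGen.randInt's upper bound is exclusive (as its
                # other uses in this file suggest), shift is in the range [0, 31]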

                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("perm0_shift0", [0]))

        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    @staticmethod
    def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        arg_list.append(("inverseTrue", [True]))
        arg_list.append(("inverseFalse", [False]))

        return arg_list

    # Helper function for reshape. Returns the factors of val that are no
    # greater than its square root.
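    # e.g. getFactors(24) -> [1, 2, 3, 4]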
    @staticmethod
    def getFactors(val, start=1):
        factors = []

        for i in range(start, int(np.sqrt(val)) + 1):
            if (val % i) == 0:
                factors.append(i)

        return factors

    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            # Rank from 1 to TOSA_TENSOR_MAX_RANK
            newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks the while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                new_shape_inferred = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                inferred_dim = testGen.rng.integers(1, newRank + 1)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    if i == inferred_dim:
                        new_shape_inferred.append(-1)
                    else:
                        new_shape_inferred.append(shuffledFactors[0])
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)
                if inferred_dim == newRank:
                    new_shape_inferred.append(-1)
                else:
                    new_shape_inferred.append(remainingElements)
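                # e.g. origShape [2, 3, 4] (24 elements) might give
                # newShape [4, 6] and, with inferred_dim == 2,
                # new_shape_inferred [4, -1]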

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                if error_name in [
                    ErrorIf.ReshapeOutputSizeNonInteger,
                    ErrorIf.ReshapeOutputSizeMultiInference,
                ]:
                    if newRank < 2:
                        # Need at least two dimensions
                        continue
                    # NOTE: Change inferred_dim starting offset from 1 to 0
                    inferred_dim -= 1
                    extra_dim = inferred_dim + testGen.rng.integers(1, newRank)
                    extra_dim = extra_dim % newRank
                    assert extra_dim != inferred_dim
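                    # For the non-integer case, grow one dimension until the
                    # element count no longer divides evenly, so the -1 dim
                    # cannot be inferred exactly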
                    if error_name == ErrorIf.ReshapeOutputSizeNonInteger:
                        elements = 1
                        for i, dim_value in enumerate(new_shape_inferred):
                            if i != inferred_dim and i != extra_dim:
                                elements *= dim_value
                        dim_value = new_shape_inferred[extra_dim]
                        while totalElements % (elements * dim_value) == 0:
                            dim_value += 1
                        new_shape_inferred[extra_dim] = dim_value
                    else:
                        assert error_name == ErrorIf.ReshapeOutputSizeMultiInference
                        new_shape_inferred[extra_dim] = -1
                else:
                    arg_list.append(
                        ("perm{}_rank{}_outdefined".format(p, newRank), [newShape])
                    )
                if error_name != ErrorIf.TensorSizeInputOutputMismatch:
                    arg_list.append(
                        (
                            "perm{}_rank{}_outinferred".format(p, newRank),
                            [new_shape_inferred],
                        )
                    )

        return arg_list

    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]

        if error_name == ErrorIf.IndexOutsideBounds:
            incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
            incorrect_small_index = range(-len(ifm_shape), 0)
            permutations = [p for p in itertools.permutations(incorrect_large_index)]
            permutations.extend(
                [p for p in itertools.permutations(incorrect_small_index)]
            )
        elif error_name == ErrorIf.IndexUsedTwice:
            # Create a list with a duplicated index
            perm_range = list(range(len(ifm_shape)))
            index_choice = testGen.rng.choice(range(len(perm_range)))
            perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
            permutations = [p for p in itertools.permutations(perm_range)]
        else:
            # Get all permutations
            permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Get a random permutation generator that uses all permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create a list of the required number of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list

    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            start = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    start.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - start[i]))

                    # A slice size of 0 is invalid
                    if size[i] == 0:
                        valid = False
                else:
                    start.append(0)
                    size.append(1)
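            # e.g. for ifm_shape [10, 1] this could give start [3, 0]
            # and size [4, 1]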

            if valid:
                # If an ERROR_IF test is required, incorrect start/size
                # values will be returned
                start, size = TosaErrorIfArgGen.eiSliceErrorIf(
                    testGen, error_name, ifm_shape, start, size
                )
                arg_list.append(("perm{}".format(p), [start, size]))
        return arg_list

    @staticmethod
    def agTile(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random, but small multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                if ifm_shape[i] > 1000:
                    # Use a multiple of 1 for large dimensions to keep
                    # the tensor size down
                    multiples.append(1)
                elif max(ifm_shape) > 1000:
                    multiples.append(2)
                else:
                    multiples.append(testGen.randInt(1, 4))
            arg_list.append(("perm{}".format(p), [multiples]))

        return arg_list

    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []
        ifm_shape = shapeList[0]

        def get_aspect_ratio_resize_params():
            common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
            aspect_ratio = testGen.rng.choice(common_aspect_ratios)
            invert = testGen.rng.choice((False, True))
            letterbox = testGen.rng.choice((False, True))

            scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
            scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
            scale_y_d = scale_x_d = 1
            offset_x = offset_y = 0

            if letterbox:
                max_border = scale_y_n
                border_y = testGen.randInt(low=0, high=max_border)
                border_x = 0
            else:
                # Pillarboxing
                border_y = 0
                max_border = scale_x_n
                border_x = testGen.randInt(low=0, high=max_border)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)

            return scale, offset, border

        def get_upscale_downscale_params():
            valid_params = False
            while not valid_params:
                upscale = testGen.rng.choice((False, True))

                # True if sampling begins from (0, 0). Otherwise (-0.5, -0.5)
                origin_sampling = testGen.rng.choice((False, True))

                if upscale:
                    shift = testGen.randInt(low=1, high=4)
                    scale_x_d = scale_y_d = 1
                    scale_x_n = scale_y_n = (
                        1 << shift if origin_sampling else 2 << shift
                    )
                    border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                    offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
                else:
                    scale_x_n = 1
                    scale_y_n = 1

                    # Return a list of valid scale_*_d values (max value 4)
                    # given the input dim shape
                    def get_valid_denom(ifm_dim):
                        return [x for x in range(1, 5) if ifm_dim % x == 1]
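                    # (ifm_dim % 1 == 1 is never true, so a denominator of 1 is
                    # never returned; the fallbacks below cover that case)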

                    # Generate a list of valid downscale values and choose one randomly
                    valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                    valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                    if not valid_scale_y_ds and not valid_scale_x_ds:
                        # Bad parameters, skip
                        continue

                    if not valid_scale_y_ds:
                        scale_y_d = 1
                    else:
                        scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                    if not valid_scale_x_ds:
                        scale_x_d = 1
                    else:
                        scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                    border_x = border_y = 0
                    offset_y = testGen.randInt(0, 16 * scale_y_n)
                    offset_x = testGen.randInt(0, 16 * scale_x_n)
                valid_params = True

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_rand_params():
            def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
                scale = scale_n / scale_d
                if scale > max_scale:
                    factor = scale / max_scale
                    new_scale_d = math.ceil(scale_d * factor)
                    assert scale_n / new_scale_d <= max_scale
                    scale_d = new_scale_d
                return scale_d

            # Scale
            scale_y_n = testGen.randInt(low=1, high=(1 << 11))
            scale_x_n = testGen.randInt(low=1, high=(1 << 11))

            scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

            scale_y_d = fix_scale_to_max_scale(
                scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_x_d = fix_scale_to_max_scale(
                scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )

            # Offsets and border within the scale
            offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
            border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_level_8k_params():
            # Create 64x scale - 64/1 to 2048/32
            scale_d = testGen.randInt(
                low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
            # Create half to fifth scaling
            scale_d_alt = testGen.randInt(low=2, high=6)
            scale_n_alt = 1
            switch = testGen.rng.choice((False, True))
            if switch:
                scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
            else:
                scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)

            offset_y = testGen.rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
            offset_x = testGen.rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
            offset = (offset_y, offset_x)
            border_y = testGen.rng.choice((-16 * scale[0], 0, scale[0] - 1))
            border_x = testGen.rng.choice((-16 * scale[2], 0, scale[2] - 1))
            border = (border_y, border_x)
            return scale, offset, border

        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.BF16:
                outputDTypeList = [DType.BF16]
            elif dtype == DType.FP32:
                outputDTypeList = [DType.FP32]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                perm = 0
                while perm < testGen.args.num_rand_permutations:
                    # Random choice of the type of params we are testing
                    if not testGen.args.level8k:
                        _rnd_param_fn = testGen.rng.choice(
                            (
                                get_rand_params,
                                get_upscale_downscale_params,
                                get_aspect_ratio_resize_params,
                            )
                        )
                        scale, offset, border = _rnd_param_fn()
                    else:
                        scale, offset, border = get_level_8k_params()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
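                    # e.g. ifm_shape[1] = 16, scale_y_n/d = 2/1, offset_y = 0,
                    # border_y = 0: partial_output_y = (16 - 1) * 2 = 30 and
                    # output_y = 30 // 1 + 1 = 31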
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        # Look for a non-integer test
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            if perm > 0:
                                perm += 1
                            continue
                    else:
                        # Alter the scaling factors to make the output integer
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1
                        # Make sure we are still within max scaling
                        if (
                            scale_y_n / scale_y_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
                            scale_x_n / scale_x_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
                            # Skip the test as it is using too large a scaling factor
                            if perm > 0:
                                perm += 1
                            continue

                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1

                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        if not testGen.args.level8k or perm > 0:
                            perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create an invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]

                    # Common for all data types
                    if error_name is not None:
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        [
                            mode,
                            scale,
                            offset,
                            border,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)
        return arg_list

    @staticmethod
    def agTable(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        if dtype == DType.INT8:
            table = np.int32(
                testGen.rng.integers(low=-128, high=128, size=[256])
            ).tolist()
        else:  # INT16
            table = np.int32(
                testGen.rng.integers(low=-32768, high=32768, size=[513])
            ).tolist()
            # Make sure all slopes are within the REQUIRE min/max 16-bit int
            for idx in range(len(table) - 1):
                slope = table[idx + 1] - table[idx]
                # Alter the next table entry to force the slope to be ok
                if slope > 32767:
                    table[idx + 1] -= slope - 32767
                if slope < -32768:
                    table[idx + 1] -= slope + 32768
                slope = table[idx + 1] - table[idx]
                assert slope <= 32767 and slope >= -32768
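            # e.g. table[idx] = -30000, table[idx + 1] = 10000 gives a slope of
            # 40000; the next entry is reduced to 2767, making the slope exactly 32767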
        arg_list.append(
            (
                "",
                [table],
            )
        )
        return arg_list

    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
        # While loop: 0 iterations, 1, more than 1
        arg_list = []

        for iter in [0, 1, 4]:
            arg_list.append(("iter{}".format(iter), [iter]))

        return arg_list