# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
import math
import warnings

import generator.tosa_utils as gtu
import numpy as np
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode

# DTypeNames, DType, Op and ResizeMode are convenience variables to the
# flatc-generated types that should be enums, but aren't

class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(testGen, dtype, error_name=None):

        if dtype == DType.INT8:
            if testGen.args.zeropoint is not None:
                return min(127, max(-128, testGen.args.zeropoint))
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            if testGen.args.zeropoint is not None:
                return min(255, max(0, testGen.args.zeropoint))
            return testGen.randInt(0, 256)
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            zero_point = testGen.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
        #     scaleFp, scaleBits, m, multiplier, shift))

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift
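
    # Illustrative worked example (added note, not part of the upstream
    # generator): computeMultiplierAndShift(0.125, scale32=True) returns
    # (1073741824, 33). math.frexp(0.125) gives (0.5, -2), so the multiplier
    # is round(0.5 * 2**31) == 2**30 and shift = -(-2) + 31 = 33; checking,
    # 2**30 * 2.0**-33 == 0.125 recovers the original scale.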


class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = min(values_in_shape[0], testGen.args.max_batch_size)

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        fuzz_idx = testGen.randInt(0, rank)

        for i in range(pl + const):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list
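
    # Illustrative example (added note): for shape [3, 4, 5] with
    # fuzz_idx == 1 and bcast_idx == 1 (and no error_name), input 0 keeps
    # [3, 4, 5] while input 1 becomes [3, 1, 5], so the test exercises
    # broadcasting along the fuzzed dimension.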

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension()

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]
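
    # Illustrative example (added note): the power-of-two rounding above maps
    # a height or width of 10 to 2 ** int(math.log(10, 2)) == 2 ** 3 == 8,
    # the nearest power of two at or below the random dimension.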

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
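
    # Illustrative example (added note): with four equal input shapes of
    # axis length 8, the split above yields axis lengths [8, 4, 2, 2] -
    # the original shape first, then successive halves of the remainder,
    # keeping the total size of the extra const inputs equal to the original.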


class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each tensor in each test."""

    def __init__(self):
        pass

    class TVGInfo:
        """Enhanced tensor values information including data gen dict."""

        def __init__(self, tensorList, dataGenDict):
            self.tensorList = tensorList
            self.dataGenDict = dataGenDict

    @staticmethod
    def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        pCount, cCount = op["operands"]

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        return tens

    @staticmethod
    def tvgLazyGenDefault(
        testGen, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]

        overrideLazy = False
        if not gtu.dtypeIsFloat(dtypeList[0]) and testGen.args.lazy_data_gen:
            # TEMPORARY OVERRIDE for integer types
            overrideLazy = True
            testGen.args.lazy_data_gen = False

        # TODO - Change to generation of data using library!
        # For now - we fall back to original path (or when dealing with non-floats)
        if not testGen.args.lazy_data_gen:
            tens_ser_list = TosaTensorValuesGen.tvgDefault(
                testGen,
                testGen.TOSA_OP_LIST[opName],
                dtypeList,
                shapeList,
                [],
                error_name,
            )
            if overrideLazy:
                # Return to lazy mode
                testGen.args.lazy_data_gen = True
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
        dg_type = argsDict["dg_type"]
        dg_tens_meta = {}
        tens_ser_list = []
        for idx, shape in enumerate(shapeList):

            tens_meta = {}
            tens_meta["generator"] = gtu.DataGenType(dg_type).name
            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = opName

            if idx < pCount:
                tens_meta["input_type"] = "variable"
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], None)
            else:
                tens_meta["input_type"] = "constant"
                tens = testGen.ser.addConst(shape, dtypeList[idx], None)
            tens_ser_list.append(tens)

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                # TODO - generate seed for this generator based on test
                info["rng_seed"] = -1
                info["range"] = [
                    str(v)
                    for v in testGen.getDTypeRange(dtypeList[idx], high_inclusive=True)
                ]
                tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = argsDict["ks"]
                for key in gtu.DG_DOT_PRODUCT_OPTIONAL_INFO:
                    if key in argsDict:
                        if key.endswith("_type"):
                            info[key] = gtu.DTYPE_ATTRIBUTES[argsDict[key]]["json"]
                        else:
                            info[key] = argsDict[key]
                tens_meta["dot_product_info"] = info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"
            dg_tens_meta[tens.name] = tens_meta

        tens_data = {
            "version": "0.1",
            "tensors": dg_tens_meta,
        }
        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)

    @staticmethod
    def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32 and error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.NEGATE must have 1 placeholder, 0 consts"
            # Must create tensors with values within accumulator (int32) negatable
            # range
            max_val = (1 << 31) - 1
            min_val = -max_val
            arr = np.int32(
                testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Make sure the operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            placeholders = []
            add = op["op"] == Op.ADD
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
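
    # Illustrative worked example (added note): for ADD with
    # a = 2_000_000_000 and b = 500_000_000, the int64 result 2_500_000_000
    # exceeds max_i32 (2_147_483_647) by 352_516_353, so b is clipped to
    # 500_000_000 - 352_516_353 = 147_483_647 and a + b then lands exactly
    # on max_i32 without wrapping.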

    @staticmethod
    def tvgCondIfWhileLoop(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        if dtypeList[0] in (
            DType.INT32,
            DType.INT16,
            DType.INT8,
        ):
            # Limit input tensors with cond_if_binary or while_loop to stop
            # saturation of add/sub ops with int32 and keep all logical shift
            # values between 0 to 31 for int16 or int8
            pCount, cCount = op["operands"]
            pRemain = pCount
            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[0] == DType.INT32:
                    arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
                else:
                    arr = np.int32(
                        testGen.rng.integers(low=0, high=32, size=shapeList[idx])
                    )
                if pRemain > 0:
                    placeholders.append(
                        testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
                    )
                    pRemain -= 1
                else:
                    placeholders.append(
                        testGen.ser.addConst(shape, dtypeList[idx], arr)
                    )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgArithmeticRightShift(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        pCount, cCount = op["operands"]
        # Force value of operand[1] to be within [0, num_bits]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

        placeholders = []
        for idx, shape in enumerate(shapeList[:]):
            if idx == 1:
                if dtypeList[idx] == DType.INT8:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                elif dtypeList[idx] == DType.INT16:
                    arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
                elif dtypeList[idx] == DType.INT32:
                    arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
                elif error_name == ErrorIf.WrongInputType:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                else:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
            else:
                arr = testGen.getRandTensor(shape, dtypeList[idx])
            placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))

        return placeholders

    @staticmethod
    def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        # Set datatype of condition tensor to boolean
        dtypeList[0] = DType.BOOL

        return TosaTensorValuesGen.tvgDefault(
            testGen, op, dtypeList, shapeList, testArgs, error_name
        )

    @staticmethod
    def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
            while True:
                dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
                    continue

                break

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
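
    # Added note: case 2 above is rejected because the exact quotient
    # -(2**31) / -1 == 2**31, which is one past the int32 maximum of
    # 2**31 - 1 and so cannot be represented in the result tensor.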

    @staticmethod
    def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            tens = []
            if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32):
                tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure the multiply result is in int32 range
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                elif error_name == ErrorIf.WrongInputType:
                    num_bits = 8
                else:
                    raise Exception("OpMul: invalid input dtype")

                for idx, shape in enumerate(shapeList[:]):
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                    a_arr = np.int32(
                        testGen.rng.integers(low=low, high=high, size=shapeList[0])
                    )
                    b_arr = np.int32(
                        testGen.rng.integers(low=low, high=high, size=shapeList[1])
                    )

                    i = 0
                    while True:

                        a_arr_64 = a_arr.astype(np.int64)
                        b_arr_64 = b_arr.astype(np.int64)

                        if shift > 0:
                            rounding = 1 << (shift - 1)
                            result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                        else:
                            result_arr = a_arr_64 * b_arr_64

                        if (result_arr > -(2**31)).all() and (
                            result_arr <= ((2**31) - 1)
                        ).all():
                            break

                        i = i + 1
                        a_arr = a_arr // 2
                        b_arr = b_arr // 2

                placeholders.append(
                    testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)

            return tens
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
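
    # Added note: the halving loop above terminates because floor-dividing
    # both operands by two keeps shrinking the large-magnitude values, so
    # the 64-bit product (optionally rounded and right-shifted) eventually
    # falls inside the int32 bounds checked before the break.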

    @staticmethod
    def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        count = len(shapeList) - testGen.args.num_const_inputs_concat
        if count < 1:
            count = 1
        if testGen.args.num_const_inputs_concat == 0:
            count = len(shapeList)

        # Ensure axis is an int
        testArgs[0] = int(testArgs[0])

        shapeList = TosaTensorGen.tgConcatConstInput(
            testGen, shapeList, testArgs[0], error_name
        )

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
        )
        tens.extend(testGen.buildConstTensors(shapeList[count:], dtypeList[count:]))

        return tens

    @staticmethod
    def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        pCount, cCount = op["operands"]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
        values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
        shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
        placeholders = []
        placeholders.append(
            testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
        )
        placeholders.append(
            testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
        )

        return placeholders

    @staticmethod
    def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        testGen.rng.integers(
                            0, max(shapeList[0][axis], shapeList[1][axis])
                        )
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.REDUCE_SUM must have 1 placeholder, 0 consts"
            # Limit values so that the sum cannot exceed the range of an int32 during
            # summation of any axis
            range_val = int((1 << 31) / max(shapeList[0]))
            values_arr = np.int32(
                testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
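
    # Added note: for a shape of [4, 100], range_val is
    # int(2**31 / 100) == 21_474_836, so even summing 100 values of maximum
    # magnitude along an axis stays within about +/-2_147_483_600, inside
    # the int32 range.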


class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for
    operators that take attributes or other parameters.

    The return value is a list of (descriptive_name, [arglist]) tuples where
    the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function.
    """

    def __init__(self):
        pass

    @staticmethod
    def _add_data_generators(testGen, opName, dtype, arg_list, error_name, **kwargs):
        """Add extra tests for each type of data generator for this op."""
        if error_name is None and "data_gen" in testGen.TOSA_OP_LIST[opName]:
            if dtype in [DType.FP16, DType.FP32, DType.BF16]:
                dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
            else:
                dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
        else:
            # Error test or no data generator types listed - assume random
            dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)

        # Expand arg list with other data generator types
        new_arg_list = []
        for dg_type in dataGenTypesList:
            for arg_str, arg_attrs in arg_list:
                arg_dict = arg_attrs[0]
                arg_dict["dg_type"] = dg_type

                if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                    # Default test
                    new_arg_list.append((arg_str, [arg_dict]))

                elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                    # Extra tests for each dot product test set
                    dot_products = kwargs["dot_products"]
                    if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
                        print(
                            f"Skipping {opName} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
                        )
                        continue
                    arg_dict["ks"] = kwargs["ks"]
                    for key in gtu.DG_DOT_PRODUCT_OPTIONAL_INFO:
                        if key in kwargs:
                            arg_dict[key] = kwargs[key]

                    for s in testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS:
                        new_arg_str = f"{arg_str}_s{s}"
                        new_arg_dict = arg_dict.copy()
                        new_arg_dict["s"] = s
                        new_arg_list.append((new_arg_str, [new_arg_dict]))

        return new_arg_list

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype, error_name=None):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype, error_name=None):
        """Build the axis argument for operators that take a single axis"""
        axes = []
        shape = shapeList[0]

        if error_name == ErrorIf.AxisSmallerZero:
            small_axis = testGen.rng.integers(-5, 0)
            axes.append(("axis{}".format(small_axis), [small_axis]))
        elif error_name == ErrorIf.AxisLargerRank:
            large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
            axes.append(("axis{}".format(large_axis), [large_axis]))
        else:
            for a in range(0, len(shape)):
                axes.append(("axis{}".format(a), [a]))

        return axes

    @staticmethod
    def _calculate_sparsity(num_tests, sparsity_factor):
        sparsity = num_tests // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
        return sparsity
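
    # Illustrative example (added note): for num_tests=24000 and
    # sparsity_factor=120, the initial sparsity is 201; 201 is divisible
    # by 3 so it steps to 202 (even) and then 203, which shares no factor
    # with 2, 3 or 5, so roughly one in every 203 combinations is kept.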

    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        # Shape: Batches, (Depth), Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: (OFM channels), (KD), KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Check the rank
        conv3d = opName.startswith("conv3d")
        rank = 5 if conv3d else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits channels
        k_rank = rank - 2
        k_pos = 0 if opName.startswith("depthwise") else 1
        k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            if error_name == ErrorIf.PadSmallerZero:
                p_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                # Stride must be greater than 1 to force non-integer error
                startStride = (
                    1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
                )
                s_vals = [
                    x for x in range(startStride, testGen.args.max_conv_stride + 1)
                ]
            strides = {x for x in itertools.product(*([s_vals] * k_rank))}
            if error_name == ErrorIf.DilationSmallerOne:
                d_vals = [testGen.rng.choice(range(-5, 1))]
            else:
                d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
            dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[0, bigPadding]] * (k_rank * 2))
                            )
                        }
                    )
                bigStride = 8
                strides.update(
                    {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
                )
                bigDilation = 7
                dilations.update(
                    {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
                )
            max_dim_size = None

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(paddings) * len(strides) * len(dilations), sparsity_factor
            )
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            dilation_shape = [1] * k_rank
            pad_shape = [0] * k_rank * 2
            if conv3d:
                # Small stride apart from for big kernel (see below) to keep
                # tensor size/calculation small
                stride_shape = [1] * k_rank
                for idx in range(k_rank):
                    pad_offset = idx * 2
                    if k_shape[idx] == bigKernel:
                        # Padding shape needs to account for tensor shape
                        pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                        pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                        # Big stride to reduce output size
                        stride_shape[idx] = bigKernel
                    else:
                        # Account for kernel size
                        pad_shape[pad_offset] = k_shape[idx] - 1
            else:
                # Always have a large stride with extra padding and dilation to keep
                # tensor calculation reasonable
                stride_shape = [bigKernel] * k_rank
                for idx in range(k_rank):
                    # Dilation shape must account for kernel size
                    dilation_shape[idx] = bigKernel // k_shape[idx]
                    # Padding shape needs to accommodate tensor/kernel & dilation
                    pad_offset = idx * 2
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

            strides = {tuple(stride_shape)}
            dilations = {tuple(dilation_shape)}
            paddings = {tuple(pad_shape)}
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a positive
                        # sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
                        and (
                            k_rank < 3
                            or (
                                (ifm_shape[3] - 1 + p[4] + p[5])
                                > d[2] * (k_shape[2] - 1)
                            )
                        )
                    ):
                        remainders = []
                        outputs = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            partial = (
                                ifm_shape[index + 1]
                                - 1
                                + p[pad_offset]
                                + p[pad_offset + 1]
                                - (k_shape[index] - 1) * d[index]
                            )
                            remainders.append(partial % s[index])
                            outputs.append((partial // s[index]) + 1)

                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            if (
                                max_dim_size is not None
                                and max(outputs) >= max_dim_size
                            ):
                                # Test will consume too much memory - skip it
                                continue

                            # Support for larger values than 9 needs different delimiter
                            delim = "" if max(s + p + d) <= 9 else "x"
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        delim.join([str(x) for x in s]),
                                        delim.join([str(x) for x in p]),
                                        delim.join([str(x) for x in d]),
                                    ),
                                    [accum_dtype, s, p, d],
                                )
                            )
                    n += 1

        return arg_list
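
    # Illustrative worked example (added note) for the output-shape check in
    # agConv above: with input height 16, kernel 3, padding (1, 1),
    # dilation 1 and stride 3, partial = 16 - 1 + 1 + 1 - 2 = 15; the
    # remainder 15 % 3 == 0, so the output height (15 // 3) + 1 = 6 is
    # integer exact and the combination is kept for positive tests.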

    @staticmethod
    def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):

        assert isinstance(dtypes, list) or isinstance(
            dtypes, tuple
        ), f"{dtypes} unexpected"
        input_dtype = dtypes[0]

        if error_name == ErrorIf.WrongOutputType:
            accum_dtype = gtu.get_wrong_output_type(opName, testGen.rng, input_dtype)
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            accum_dtype = DType.INT32
        else:
            accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])]

    @staticmethod
    def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
        # Get valid accumulate type(s)
        if dtype == DType.INT8:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.INT16:
            accum_dtypes = [DType.INT48]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FP32]
        elif dtype == DType.BF16:
            accum_dtypes = [DType.FP32]
        elif dtype == DType.FP32:
            accum_dtypes = [DType.FP32]
        elif error_name is None:
            assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"

        if error_name == ErrorIf.WrongOutputType:
            # Get incorrect output dtype for ErrorIf case
            accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            accum_dtypes = [DType.INT32]

        arg_list = [
            (f"acc{testGen.typeStr(a)}", [{"acc_type": a}]) for a in accum_dtypes
        ]

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
            ks=int(shapeList[0][2]),  # Set KS = C, from input A (N,H,C)
            # Set dot_products = N*H*W
            dot_products=gtu.product(
                (shapeList[0][0], shapeList[0][1], shapeList[1][2])
            ),
        )
        return arg_list

1441 @staticmethod
1442 def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001443 arg_list = []
1444
Jeremy Johnson0c716862023-04-13 17:18:19 +01001445 if testGen.args.level8k and error_name is not None:
1446 # Don't produce negative large tests
1447 return arg_list
1448
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001449 ifm_shape = shapeList[0]
1450 filter_shape = shapeList[1]
1451
Jeremy Johnson1271c442023-09-05 11:39:26 +01001452 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01001453
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001454 # Must be rank 4
1455 if error_name != ErrorIf.WrongRank:
1456 assert len(ifm_shape) == 4
1457 assert len(filter_shape) == 4
1458
Jeremy Johnson0c716862023-04-13 17:18:19 +01001459 k_shape = tuple(filter_shape[1:3])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001460
Jeremy Johnson0c716862023-04-13 17:18:19 +01001461 if not testGen.args.level8k:
1462 # Generate comprehensive argument lists
1463 # - except for named errors, which use specific invalid value(s)
1464 smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
1465 if error_name == ErrorIf.PadLargerEqualKernel:
1466 max_filter_size = -max(k_shape[0], k_shape[1])
1467 p_vals = [
1468 testGen.rng.choice(range(max_filter_size - 10, max_filter_size))
1469 ]
1470 else:
1471 p_vals = [
1472 x
1473 for x in range(
1474 smallest_padding_size, testGen.args.max_conv_padding + 1
1475 )
1476 ]
1477 paddings = {x for x in itertools.product(*([p_vals] * 4))}
1478 if error_name == ErrorIf.StrideSmallerOne:
1479 # Can't use stride=0, as it is used to derive output shape, as a divisor
1480 s_vals = [testGen.rng.choice(range(-5, 0))]
1481 else:
1482 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
1483 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001484
Jeremy Johnson0c716862023-04-13 17:18:19 +01001485 if not error_name and testGen.args.oversize:
1486 # add some oversize argument values
1487 if max(ifm_shape) < 64:
1488 bigPadding = 9
1489 paddings.update(
1490 {
1491 x
1492 for x in itertools.product(
1493 *([[smallest_padding_size, bigPadding]] * 4)
1494 )
1495 }
1496 )
1497 bigStride = 8
1498 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
1499
            # There are too many parameter combinations, so generate them
            # sparsely; very sparsely for negative tests
1502 sparsity_factor = 2 if error_name else 10
1503 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
1504 # If there are only a small number of tests, just select them all
1505 if sparsity < 13:
1506 sparsity = 1
            # To get a variety of parameter combinations, sparsity should not
            # be a multiple of 2, 3 or 5
1509 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1510 sparsity += 1
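            # Illustration (values not executed): 81 paddings and 9 strides
            # give sparsity = 81 * 9 // 10 + 1 = 73, which is not a multiple
            # of 2, 3 or 5, so roughly every 73rd combination is kept below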
1511 else:
            # Only test the 8K level boundaries
1513 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
1514 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
1515 bigPadding = bigKernel
1516
1517 pad_shape = [0] * (len(k_shape) * 2)
1518 stride_shape = [1] * len(k_shape)
            # The input size above which, combined with the large stride,
            # the output would become excessively large
1521 LARGE_SIZE = 2
1522 for idx in range(len(k_shape)):
1523 pad_offset = idx * 2
1524 if k_shape[idx] == bigKernel:
1525 # Set large stride
1526 stride_shape[idx] = bigKernel
1527 # Use negative output padding to reduce shape size
1528 pad_shape[pad_offset] = -(bigPadding - 1)
1529 if ifm_shape[idx + 1] > LARGE_SIZE:
1530 pad_shape[pad_offset + 1] = -(bigPadding - 1)
1531 else:
1532 # The other dimension should be the bigKernel
1533 alt_idx = 1 - idx
1534 if (
1535 k_shape[alt_idx] == bigKernel
1536 and ifm_shape[alt_idx + 1] < LARGE_SIZE
1537 ):
1538 # As the input is small, the large stride won't
1539 # affect the output so we can add some padding
1540 pad_shape[pad_offset + 1] = bigPadding
1541
1542 strides = {tuple(stride_shape)}
1543 paddings = {tuple(pad_shape)}
1544
1545 # Currently allow all combinations that are reasonable size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001546 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001547
1548 n = 0
1549 for s in sorted(list(strides)):
1550 for p in sorted(list(paddings)):
TatWai Chong24594f52022-06-08 00:48:04 -07001551 if n % sparsity == 0:
1552 # Determine the output shape
Jeremy Johnson0c716862023-04-13 17:18:19 +01001553 oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
1554 ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
TatWai Chong24594f52022-06-08 00:48:04 -07001555 os = [ifm_shape[0], oh, ow, filter_shape[0]]
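                    # e.g. ifm 1x4x4xC with stride (2, 2), pad (0, 0, 0, 0)
                    # and kernel 3x3 gives oh = ow = (4 - 1) * 2 + 0 + 0 + 3 = 9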
Jeremy Johnson0c716862023-04-13 17:18:19 +01001556
                    # Values larger than 9 need a different delimiter
1558 delim = "" if max(s + p) <= 9 else "x"
TatWai Chong24594f52022-06-08 00:48:04 -07001559 arg_list.append(
1560 (
James Ward8b390432022-08-12 20:48:56 +01001561 "acc{}_st{}_pad{}_os{}".format(
1562 testGen.typeStr(accum_dtype),
Jeremy Johnson0c716862023-04-13 17:18:19 +01001563 delim.join([str(x) for x in s]),
1564 delim.join([str(x) for x in p]),
TatWai Chong24594f52022-06-08 00:48:04 -07001565 "x".join([str(x) for x in os]),
1566 ),
James Ward8b390432022-08-12 20:48:56 +01001567 [accum_dtype, s, p, os],
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001568 )
TatWai Chong24594f52022-06-08 00:48:04 -07001569 )
1570 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001571
1572 return arg_list
1573
1574 @staticmethod
1575 def agPad(testGen, opName, shapeList, dtype, error_name=None):
1576 arg_list = []
1577 rank = len(shapeList[0])
1578
1579 # Exhaustively test combinations of padding on each side of each dimension
1580 # - the range of padding values is defined by pad_min and pad_max
1581 # - for padding >9, the name format needs to be more distinctive
1582 pad_min, pad_max = 0, 1
1583 pad_values = [x for x in range(pad_min, pad_max + 1)]
1584 if error_name == ErrorIf.PadSmallerZero:
1585 pad_values = [x for x in range(-2, 0)]
1586 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
1587 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
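        # e.g. for rank 2 with pad_values [0, 1], axis_pad_values is
        # [(0, 0), (0, 1), (1, 0), (1, 1)] and shape_pad_values enumerates
        # all 16 (before, after) padding combinations over the two dimensions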
1588
1589 if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
1590 pad_const_int = testGen.getRandNumberDType(dtype)
1591 pad_const_fp = 0
James Wardf0890992022-11-17 11:15:14 +00001592 elif dtype in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001593 pad_const_int = 0
1594 pad_const_fp = testGen.getRandNumberDType(dtype)
1595 else:
1596 return []
1597
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001598 list_shape_pad_values = list(shape_pad_values)
        # If we are producing tests for rank 6 or greater, use sparsity
1600 if len(list_shape_pad_values) > 1024:
1601 sparsity_factor = 2 if error_name else 120
1602 sparsity = TosaArgGen._calculate_sparsity(
1603 len(list_shape_pad_values), sparsity_factor
1604 )
1605 else:
1606 sparsity = 1
1607
1608 for n, paddings in enumerate(list_shape_pad_values):
James Ward8b390432022-08-12 20:48:56 +01001609 paddings = list(paddings)
1610 args_valid = True
1611
1612 if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while still testing for negative padding
1614 for i in range(rank):
1615 dim_after_padding = (
1616 paddings[i][0] + paddings[i][1] + shapeList[0][i]
1617 )
1618 if dim_after_padding < 1:
1619 paddings[i] = (0, 0)
1620 if all([p > -1 for p in paddings[i]]):
1621 args_valid = False
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001622 if args_valid and n % sparsity == 0:
James Ward8b390432022-08-12 20:48:56 +01001623 name = "pad"
1624 for r in range(rank):
1625 before, after = paddings[r]
1626 name = f"{name}{before}{after}"
1627 arg_list.append(
1628 (name, [np.array(paddings), pad_const_int, pad_const_fp])
1629 )
1630
1631 if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
1632 warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001633
1634 return arg_list
1635
1636 @staticmethod
1637 def agPooling(testGen, opName, shapeList, dtype, error_name=None):
1638 arg_list = []
1639
1640 shape = shapeList[0]
1641 if error_name != ErrorIf.WrongRank:
1642 assert len(shape) == 4
1643
Jeremy Johnson0c716862023-04-13 17:18:19 +01001644 test_level8k = testGen.args.level8k and error_name is None
1645
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001646 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01001647 startKernel = 2
1648 startPad = 0
1649 if not test_level8k:
1650 # Generate comprehensive argument lists
1651 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
1652 paddings = {x for x in itertools.product(*([p_vals] * 4))}
1653 # Stride must be greater than 1 to force non-integer error
1654 s_vals = [
1655 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
1656 ]
1657 strides = {x for x in itertools.product(*([s_vals] * 2))}
1658 k_vals = [
1659 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
1660 ]
1661 kernels = {x for x in itertools.product(*([k_vals] * 2))}
1662 max_dim_size = None
1663 else:
1664 # Only test 8k levels
1665 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
1666 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
1667 strides = {(1, bigStride), (bigStride, 4)}
1668 kernels = {(1, bigKernel), (bigKernel, 3)}
1669 paddings = set()
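            # Derive paddings that keep the output size integral: starting
            # from total_padding = stride - shape + kernel, the quantity
            # (shape + padding - kernel) is an exact multiple of the stride,
            # and adding further strides preserves that property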
1670 for s in sorted(list(strides)):
1671 for k in sorted(list(kernels)):
1672 padding = []
1673 for idx in range(len(k)):
1674 total_padding = s[idx] - shape[idx + 1] + k[idx]
1675 while total_padding < 0:
1676 # Must meet: shape + padding > kernel
1677 total_padding += s[idx]
1678 if total_padding < k[idx]:
1679 padding.extend([0, total_padding])
1680 else:
                            # Note: this may produce padding >= k[idx], which
                            # is not allowed, but such cases are skipped in
                            # the creation loop below
1683 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
1684 paddings.add(tuple(padding))
1685 # Create a limit for the output dimensions size
1686 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001687
James Ward8b390432022-08-12 20:48:56 +01001688 if opName == "max_pool2d":
1689 accum_dtypes = [None] # max_pool has no accumulate dtype
1690 elif dtype == DType.INT8 or dtype == DType.INT16:
1691 accum_dtypes = [DType.INT32]
1692 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001693 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01001694 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001695 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01001696 elif error_name is None:
1697 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
1698 else:
1699 # Set to something for the ErrorIf case which has
1700 # incorrect input data-type
1701 accum_dtypes = [DType.INT32]
1702
Jeremy Johnson0c716862023-04-13 17:18:19 +01001703 if not test_level8k:
1704 if testGen.args.oversize:
1705 # add some oversize argument values
1706 bigStride = 7
1707 bigKernel = 9
1708 strides.update(
1709 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001710 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01001711 kernels.update(
1712 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
1713 )
1714 if max(shape) < 64:
1715 # padding must be less than the kernel size
1716 bigPadding = bigKernel - 1
1717 paddings.update(
1718 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
1719 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001720
            # There are too many parameter combinations, so generate them
            # sparsely; very sparsely for negative tests
1723 sparsity_factor = 2 if error_name else 500
1724 sparsity = (
1725 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
1726 )
1727 else:
1728 # We have already limited test output combinations for 8k tests
1729 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001730
James Ward8b390432022-08-12 20:48:56 +01001731 arg_str = (
1732 "acc{}_st{}_kern{}_pad{}"
1733 if accum_dtypes[0] is not None
1734 else "st{}_kern{}_pad{}"
1735 )
1736
1737 def get_arg_list_element(accum, stride, pad, kern):
1738 # Return tuple containing the formatted argument string and
1739 # the corresponding argument values
Jeremy Johnson0c716862023-04-13 17:18:19 +01001740
1741 # Support for larger values than 9 needs different delimiter
1742 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01001743 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01001744 delim.join([str(x) for x in stride]),
1745 delim.join([str(x) for x in kern]),
1746 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01001747 ]
            # Note: the value order (stride, pad, kern) differs from the
            # string order (stride, kern, pad)
1749 arg_val_elems = [stride, pad, kern]
1750
1751 if accum is not None:
1752 arg_str_elems.insert(0, testGen.typeStr(accum))
1753 arg_val_elems.insert(0, accum)
1754 return (arg_str.format(*arg_str_elems), arg_val_elems)
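        # e.g. for max_pool2d (no accumulator), stride (2, 2), pad
        # (0, 0, 0, 0) and kernel (3, 3) produce the test argument
        # ("st22_kern33_pad0000", [(2, 2), (0, 0, 0, 0), (3, 3)])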
1755
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001756 n = 0
James Ward8b390432022-08-12 20:48:56 +01001757 for a in accum_dtypes:
1758 for s in sorted(list(strides)):
1759 for p in sorted(list(paddings)):
1760 for k in sorted(list(kernels)):
1761 if error_name in [
1762 ErrorIf.StrideSmallerOne,
1763 ErrorIf.KernelSmallerOne,
1764 ErrorIf.PadSmallerZero,
1765 ErrorIf.PadLargerEqualKernel,
1766 ]:
1767 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
1768 testGen, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001769 )
James Ward8b390432022-08-12 20:48:56 +01001770 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
1771 arg_vals = [a, sNew, pNew, kNew]
1772 arg_list.append(get_arg_list_element(*arg_vals))
1773 elif (
1774 n % sparsity == 0
                            # padding must be less than the kernel size
1776 and p[0] < k[0]
1777 and p[1] < k[0]
1778 and p[2] < k[1]
1779 and p[3] < k[1]
1780 # the padded shape must exceed the kernel size
1781 and (shape[1] + p[0] + p[1]) > k[0]
1782 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001783 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01001784 partial_h = shape[1] + p[0] + p[1] - k[0]
1785 partial_w = shape[2] + p[2] + p[3] - k[1]
1786 remainder_h = partial_h % s[0]
1787 remainder_w = partial_w % s[1]
1788 output_h = partial_h // s[0] + 1
1789 output_w = partial_w // s[1] + 1
1790 # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
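                            # e.g. shape H=W=8, zero padding, kernel 3x3 and
                            # stride (2, 2): partial = 8 - 3 = 5, remainder 1,
                            # so only the PoolingOutputShapeNonInteger ErrorIf
                            # case keeps this combination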
James Ward8b390432022-08-12 20:48:56 +01001791 if (
                                # the parameters must produce exact integer output
1793 error_name != ErrorIf.PoolingOutputShapeNonInteger
1794 and remainder_h == 0
1795 and remainder_w == 0
1796 ) or (
1797 error_name == ErrorIf.PoolingOutputShapeNonInteger
1798 and (remainder_h != 0 or remainder_w != 0)
1799 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01001800 if (
1801 max_dim_size is not None
1802 and max(output_h, output_w) > max_dim_size
1803 ):
1804 # Test will consume too much memory - skip it
1805 continue
James Ward8b390432022-08-12 20:48:56 +01001806 arg_vals = [a, s, p, k]
1807 arg_list.append(get_arg_list_element(*arg_vals))
1808 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001809
1810 return arg_list
1811
1812 @staticmethod
1813 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
1814 arg_list = []
1815
1816 # Enumerate the output types here
1817 if error_name == ErrorIf.WrongOutputType:
1818 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
1819 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00001820 dtypeList = [
1821 DType.BOOL,
1822 DType.INT16,
1823 DType.INT32,
1824 DType.FP16,
1825 DType.BF16,
1826 DType.FP32,
1827 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001828 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00001829 dtypeList = [
1830 DType.BOOL,
1831 DType.INT8,
1832 DType.INT32,
1833 DType.FP16,
1834 DType.BF16,
1835 DType.FP32,
1836 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001837 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00001838 dtypeList = [
1839 DType.BOOL,
1840 DType.INT8,
1841 DType.INT16,
1842 DType.FP16,
1843 DType.BF16,
1844 DType.FP32,
1845 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001846 elif inDtype == DType.BOOL:
1847 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01001848 elif inDtype == DType.FP16:
James Ward736fd1a2023-01-23 17:13:37 +00001849 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01001850 elif inDtype == DType.BF16:
James Ward736fd1a2023-01-23 17:13:37 +00001851 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001852 elif inDtype == DType.FP32:
James Ward736fd1a2023-01-23 17:13:37 +00001853 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001854 elif error_name == ErrorIf.WrongInputType:
1855 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001856 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001857 else:
1858 raise Exception("Unexpected input dtype: {}".format(inDtype))
1859
1860 for dtype in dtypeList:
Jeremy Johnson3b0544c2022-10-18 16:32:19 +01001861 arg_list.append(("out{}".format(testGen.typeStr(dtype)), [dtype]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001862
1863 return arg_list
1864
1865 @staticmethod
1866 def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
1867 arg_list = []
1868
1869 # Enumerate the output types here
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001870 for outDtype in [
1871 DType.UINT8,
1872 DType.INT8,
1873 DType.INT16,
1874 DType.INT32,
1875 DType.UINT16,
1876 ]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001877 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001878 outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001879 and error_name == ErrorIf.OutputZeroPointNotZero
1880 ):
1881 continue
1882 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001883 outDtype != DType.UINT16
1884 and error_name == ErrorIf.U16OutputZeroPointNotValid
1885 ) or (
1886 inDtype != DType.UINT16
1887 and error_name == ErrorIf.U16InputZeroPointNotValid
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001888 ):
                # These ErrorIf cases are only valid with UINT16
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001890 continue
1891 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001892 inDtype == DType.UINT8
1893 and outDtype not in [DType.INT8, DType.INT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001894 and error_name != ErrorIf.WrongOutputType
1895 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001896 # The only output dtypes for UINT8 are INT8/INT16, skip all others
1897 continue
1898 if (
1899 inDtype not in [DType.INT8, DType.INT16]
1900 and outDtype == DType.UINT8
1901 and error_name != ErrorIf.WrongOutputType
1902 ):
1903 # The only input dtypes for UINT8 are INT8/INT16, skip all others
1904 continue
1905 if (
1906 inDtype == DType.UINT16
1907 and outDtype != DType.INT16
1908 and error_name != ErrorIf.WrongOutputType
1909 ):
1910 # The only output dtype for UINT16 is INT16, skip all others
1911 continue
1912 if (
1913 inDtype != DType.INT16
1914 and outDtype == DType.UINT16
1915 and error_name != ErrorIf.WrongOutputType
1916 ):
1917 # The only input dtype for UINT16 is INT16, skip all others
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001918 continue
1919 if (
1920 error_name == ErrorIf.WrongOutputType
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001921 and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001922 ):
1923 continue
1924
1925 for scale32 in [False, True]:
1926 if error_name == ErrorIf.ScaleTrue and not scale32:
1927 continue
1928 elif error_name == ErrorIf.ScaleNotTrue and scale32:
1929 continue
1930 for double_round in [False, True]:
1931 if error_name == ErrorIf.ScaleNotTrue and not double_round:
1932 continue
1933 for per_channel in [False, True]:
1934
1935 if (
1936 inDtype == DType.INT48
1937 and scale32
1938 and error_name != ErrorIf.ScaleTrue
1939 ):
1940 # Illegal condition. Must be scale32=False
1941 continue
1942 if (
1943 double_round
1944 and not scale32
1945 and error_name != ErrorIf.ScaleNotTrue
1946 ):
1947 # Illegal condition. ERROR_IF(!scale32 && double_round)
1948 continue
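                        # Remaining combinations are legal: INT48 input
                        # requires scale32=False, and double_round requires
                        # scale32=True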
1949
1950 arg_list.append(
1951 (
1952 "out{}_sc{}_dr{}_pc{}".format(
Jeremy Johnson3b0544c2022-10-18 16:32:19 +01001953 testGen.typeStr(outDtype),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001954 int(scale32),
1955 int(double_round),
1956 int(per_channel),
1957 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01001958 [outDtype, scale32, double_round, per_channel],
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001959 )
1960 )
1961
1962 return arg_list
1963
1964 @staticmethod
1965 def agMul(testGen, opName, shapeList, dtype, error_name=None):
1966 arg_list = []
1967
        if dtype == DType.INT32:
1969 for p in range(testGen.args.num_rand_permutations):
1970
1971 shift = testGen.randInt(0, 32)
1972
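                # Only INT32 MUL takes a result right-shift, e.g. shift=17 on
                # permutation 3 produces the argument ("perm3_shift17", [17])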
1973 arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
1974 else:
1975 arg_list.append(("perm0_shift0", [0]))
1976
1977 return arg_list
1978
1979 @staticmethod
1980 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
1981 arg_list = []
1982
1983 arg_list.append(("roundTrue", [True]))
1984 arg_list.append(("roundFalse", [False]))
1985
1986 return arg_list
1987
Luke Hutton57287132023-02-06 14:54:18 +00001988 @staticmethod
1989 def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
1990 arg_list = []
1991
1992 arg_list.append(("inverseTrue", [True]))
1993 arg_list.append(("inverseFalse", [False]))
1994
1995 return arg_list
1996
    # Helper function for reshape. Returns the factors of val, from start up
    # to sqrt(val).
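    # e.g. getFactors(24) returns [1, 2, 3, 4]; larger factors are reached
    # in agReshape below via the remaining element count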
1998 @staticmethod
1999 def getFactors(val, start=1):
2000 factors = []
2001
2002 for i in range(start, int(np.sqrt(val)) + 1):
2003 if (val % i) == 0:
2004 factors.append(i)
2005
2006 return factors
2007
2008 @staticmethod
2009 def agReshape(testGen, opName, shapeList, dtype, error_name=None):
2010 arg_list = []
2011
2012 origShape = shapeList[0]
2013
2014 totalElements = 1
2015 for s in origShape:
2016 totalElements *= s
2017
2018 # This code is NOT fast. Fortunately, the numbers are fairly small.
2019 factors = TosaArgGen.getFactors(totalElements)
2020
2021 for p in range(testGen.args.num_rand_permutations):
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002022 # Rank from 1 to TOSA_TENSOR_MAX_RANK
2023 newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002024 if len(factors) < newRank:
2025 continue
2026
2027 found = True
            # escape_counter breaks the while loop if it runs for too long
2029 escape_counter = 0
2030 while found:
2031 newShape = []
Jerry Ge264f7fa2023-04-21 22:49:57 +00002032 new_shape_inferred = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002033 # Generate newShape ensuring it isn't a duplicate
2034 remainingElements = totalElements
2035 shuffledFactors = testGen.rng.permutation(factors)
Jerry Ge264f7fa2023-04-21 22:49:57 +00002036 inferred_dim = testGen.rng.integers(1, newRank + 1)
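                # new_shape_inferred mirrors newShape with one dimension
                # replaced by -1, e.g. newShape [2, 3, 4] with inferred_dim 2
                # becomes [2, -1, 4]; the -1 is later resolved from the total
                # element count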
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002037 for i in range(1, newRank):
2038 # pick rank-1 factors
2039 newShape.append(shuffledFactors[0])
2040 remainingElements = remainingElements // shuffledFactors[0]
Jerry Ge264f7fa2023-04-21 22:49:57 +00002041 if i == inferred_dim:
2042 new_shape_inferred.append(-1)
2043 else:
2044 new_shape_inferred.append(shuffledFactors[0])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002045 shuffledFactors = testGen.rng.permutation(
2046 TosaArgGen.getFactors(remainingElements)
2047 )
2048 newShape.append(remainingElements)
Jerry Ge264f7fa2023-04-21 22:49:57 +00002049 if inferred_dim == newRank:
2050 new_shape_inferred.append(-1)
2051 else:
2052 new_shape_inferred.append(remainingElements)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002053
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002054 # Check for duplicates
2055 found = False
2056 for name, other_shape in arg_list:
2057 if other_shape[0] == newShape:
2058 found = True
2059 break
2060
2061 escape_counter += 1
2062 if escape_counter >= 100:
2063 break
2064
2065 if not found:
Jerry Ge264f7fa2023-04-21 22:49:57 +00002066 if error_name in [
2067 ErrorIf.ReshapeOutputSizeNonInteger,
2068 ErrorIf.ReshapeOutputSizeMultiInference,
2069 ]:
2070 if newRank < 2:
2071 # Need at least two dimensions
2072 continue
                    # NOTE: Convert inferred_dim from 1-based to 0-based indexing
2074 inferred_dim -= 1
2075 extra_dim = inferred_dim + testGen.rng.integers(1, newRank)
2076 extra_dim = extra_dim % newRank
2077 assert extra_dim != inferred_dim
2078 if error_name == ErrorIf.ReshapeOutputSizeNonInteger:
2079 elements = 1
2080 for i, dim_value in enumerate(new_shape_inferred):
2081 if i != inferred_dim and i != extra_dim:
2082 elements *= dim_value
2083 dim_value = new_shape_inferred[extra_dim]
2084 while totalElements % (elements * dim_value) == 0:
2085 dim_value += 1
2086 new_shape_inferred[extra_dim] = dim_value
2087 else:
2088 assert error_name == ErrorIf.ReshapeOutputSizeMultiInference
2089 new_shape_inferred[extra_dim] = -1
2090 else:
2091 arg_list.append(
2092 ("perm{}_rank{}_outdefined".format(p, newRank), [newShape])
2093 )
2094 if error_name != ErrorIf.TensorSizeInputOutputMismatch:
2095 arg_list.append(
2096 (
2097 "perm{}_rank{}_outinferred".format(p, newRank),
2098 [new_shape_inferred],
2099 )
2100 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002101
2102 return arg_list
2103
2104 @staticmethod
2105 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
2106 arg_list = []
2107
2108 ifm_shape = shapeList[0]
2109
2110 if error_name == ErrorIf.IndexOutsideBounds:
2111 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
2112 incorrect_small_index = range(-len(ifm_shape), 0)
2113 permutations = [p for p in itertools.permutations(incorrect_large_index)]
2114 permutations.extend(
2115 [p for p in itertools.permutations(incorrect_small_index)]
2116 )
2117 elif error_name == ErrorIf.IndexUsedTwice:
2118 # Create list with a duplicated index
2119 perm_range = list(range(len(ifm_shape)))
2120 index_choice = testGen.rng.choice(range(len(perm_range)))
2121 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
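            # e.g. for rank 3, perm_range [0, 1, 2] may become [0, 1, 1]; with
            # a duplicated index, every permutation triggers IndexUsedTwice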
2122 permutations = [p for p in itertools.permutations(perm_range)]
2123
2124 else:
2125 # Get all permutations
2126 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
2127
2128 # Limit to possible permutations from shape dimension or argument setting
2129 limit = min(len(permutations), testGen.args.num_rand_permutations)
2130
2131 # Get random permutation generator that uses all permutations
2132 random_permutations = testGen.rng.permutation(permutations)
2133
2134 # Create list of required amount of permutations
2135 arg_list = [
2136 ("perm{}".format(p), [random_permutations[p].tolist()])
2137 for p in range(limit)
2138 ]
2139 return arg_list
2140
2141 @staticmethod
2142 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
2143 arg_list = []
2144
2145 ifm_shape = shapeList[0]
2146 rank = len(ifm_shape)
2147
2148 for p in range(testGen.args.num_rand_permutations):
2149 start = []
2150 size = []
2151
2152 valid = True
2153
2154 for i in range(rank):
2155 if ifm_shape[i] > 1:
2156 start.append(testGen.randInt(0, ifm_shape[i]))
2157 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
2158
2159 # Invalid slice size?
2160 if size[i] == 0:
2161 valid = False
2162 else:
2163 start.append(0)
2164 size.append(1)
2165
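            # e.g. ifm_shape [1, 8, 8] could give start [0, 2, 3] and size
            # [1, 4, 2]; dimensions of size 1 are always sliced whole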
2166 if valid:
                # If an ERROR_IF test is required, incorrect start/size values
                # will be returned
2168 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
2169 testGen, error_name, ifm_shape, start, size
2170 )
2171 arg_list.append(("perm{}".format(p), [start, size]))
2172 return arg_list
2173
2174 @staticmethod
2175 def agTile(testGen, opName, shapeList, dtype, error_name=None):
2176 arg_list = []
2177
2178 ifm_shape = shapeList[0]
2179 rank = len(ifm_shape)
2180
2181 for p in range(testGen.args.num_rand_permutations):
2182
            # Pick a few random but small multiple values, because otherwise
            # this has a tendency to generate enormous tensors
2186 multiples = []
2187 for i in range(rank):
2188 if ifm_shape[i] > 1000:
                    # Use a multiple of 1 if the ifm_shape dimension is large,
                    # to limit the tensor size
2191 multiples.append(1)
2192 elif max(ifm_shape) > 1000:
2193 multiples.append(2)
2194 else:
2195 multiples.append(testGen.randInt(1, 4))
2196 arg_list.append(("perm{}".format(p), [multiples]))
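            # e.g. shape [3, 1117, 5] gives multiples like [2, 1, 2]: the
            # large dimension is kept at 1 and the others are limited to 2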
2197
2198 return arg_list
2199
2200 @staticmethod
2201 def agResize(testGen, opName, shapeList, dtype, error_name=None):
2202 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002203 ifm_shape = shapeList[0]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002204
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002205 def get_aspect_ratio_resize_params():
2206 common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
2207 aspect_ratio = testGen.rng.choice(common_aspect_ratios)
2208 invert = testGen.rng.choice((False, True))
2209 letterbox = testGen.rng.choice((False, True))
2210
2211 scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
2212 scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
2213 scale_y_d = scale_x_d = 1
2214 offset_x = offset_y = 0
2215
2216 if letterbox:
2217 max_border = scale_y_n
2218 border_y = testGen.randInt(low=0, high=max_border)
2219 border_x = 0
2220 else:
2221 # Pillarboxing
2222 border_y = 0
2223 max_border = scale_x_n
2224 border_x = testGen.randInt(low=0, high=max_border)
2225
2226 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
2227 offset = (offset_y, offset_x)
2228 border = (border_y, border_x)
2229
2230 return scale, offset, border
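        # e.g. aspect ratio (4, 3) without invert gives scale (3, 1, 4, 1),
        # i.e. upscale H by 3 and W by 4; letterboxing then adds a border in
        # the Y dimension only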
2231
2232 def get_upscale_downscale_params():
2233 valid_params = False
2234 while not valid_params:
2235 upscale = testGen.rng.choice((False, True))
2236
2237 # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
2238 origin_sampling = testGen.rng.choice((False, True))
2239
2240 if upscale:
2241 shift = testGen.randInt(low=1, high=4)
2242 scale_x_d = scale_y_d = 1
2243 scale_x_n = scale_y_n = (
2244 1 << shift if origin_sampling else 2 << shift
2245 )
2246 border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
2247 offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
2248 else:
2249 scale_x_n = 1
2250 scale_y_n = 1
2251
2252 # Return list of valid scale_*_d values (max value 4) given input dim shape
2253 def get_valid_denom(ifm_dim):
2254 return [x for x in range(1, 5) if ifm_dim % x == 1]
2255
2256 # Generate list of valid downscale values and choose one randomly
2257 valid_scale_y_ds = get_valid_denom(ifm_shape[1])
2258 valid_scale_x_ds = get_valid_denom(ifm_shape[2])
2259
2260 if not valid_scale_y_ds and not valid_scale_x_ds:
2261 # Bad parameters, skip
2262 continue
2263
2264 if not valid_scale_y_ds:
2265 scale_y_d = 1
2266 else:
2267 scale_y_d = testGen.rng.choice(valid_scale_y_ds)
2268
2269 if not valid_scale_x_ds:
2270 scale_x_d = 1
2271 else:
2272 scale_x_d = testGen.rng.choice(valid_scale_x_ds)
2273
2274 border_x = border_y = 0
2275 offset_y = testGen.randInt(0, 16 * scale_y_n)
2276 offset_x = testGen.randInt(0, 16 * scale_x_n)
2277 valid_params = True
2278
2279 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
2280 offset = (offset_y, offset_x)
2281 border = (border_y, border_x)
2282 return scale, offset, border
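        # e.g. upscale with shift=2 and (0, 0) origin sampling gives scale
        # (4, 1, 4, 1) with zero offset and border; the (-0.5, -0.5) variant
        # doubles the numerator to 8 and compensates with border 3, offset -3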
2283
2284 def get_rand_params():
Jeremy Johnsonb2099702023-04-12 15:59:01 +01002285 def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
2286 scale = scale_n / scale_d
2287 if scale > max_scale:
2288 factor = scale / max_scale
2289 new_scale_d = math.ceil(scale_d * factor)
2290 assert scale_n / new_scale_d <= max_scale
2291 scale_d = new_scale_d
2292 return scale_d
2293
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002294 # Scale
2295 scale_y_n = testGen.randInt(low=1, high=(1 << 11))
2296 scale_x_n = testGen.randInt(low=1, high=(1 << 11))
2297
2298 scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
2299 scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))
2300
Jeremy Johnsonb2099702023-04-12 15:59:01 +01002301 scale_y_d = fix_scale_to_max_scale(
2302 scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
2303 )
2304 scale_x_d = fix_scale_to_max_scale(
2305 scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
2306 )
2307
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002308 # Offsets and border within the scale
2309 offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
2310 offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
2311 border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
2312 border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)
2313
2314 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
2315 offset = (offset_y, offset_x)
2316 border = (border_y, border_x)
2317 return scale, offset, border
2318
Jeremy Johnsonb2099702023-04-12 15:59:01 +01002319 def get_level_8k_params():
2320 # Create 64x scale - 64/1 to 2048/32
2321 scale_d = testGen.randInt(
2322 low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
2323 )
2324 scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
2325 # Create half to fifth scaling
2326 scale_d_alt = testGen.randInt(low=2, high=6)
2327 scale_n_alt = 1
2328 switch = testGen.rng.choice((False, True))
2329 if switch:
2330 scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
2331 else:
2332 scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)
2333
2334 offset_y = testGen.rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
2335 offset_x = testGen.rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
2336 offset = (offset_y, offset_x)
2337 border_y = testGen.rng.choice((-16 * scale[0], 0, scale[0] - 1))
2338 border_x = testGen.rng.choice((-16 * scale[2], 0, scale[2] - 1))
2339 border = (border_y, border_x)
2340 return scale, offset, border
2341
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002342 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002343 # Exclude illegal {mode, type} configurations. Pick legal output types
2344 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
2345 outputDTypeList = [DType.INT8]
2346 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
2347 outputDTypeList = [DType.INT16]
2348 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
2349 outputDTypeList = [DType.INT32]
2350 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
2351 outputDTypeList = [DType.INT48]
James Ward8b390432022-08-12 20:48:56 +01002352 elif dtype == DType.FP16:
2353 outputDTypeList = [DType.FP16]
James Ward24dbc422022-10-19 12:20:31 +01002354 elif dtype == DType.BF16:
2355 outputDTypeList = [DType.BF16]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002356 elif dtype == DType.FP32:
2357 outputDTypeList = [DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002358 elif error_name == ErrorIf.WrongInputType:
2359 # If an incorrect input type is used then we set a 'correct'
2360 # output type to avoid other errors
2361 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
2362 else:
2363 continue
2364
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002365 arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"
2366
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002367 for outputDType in outputDTypeList:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002368 perm = 0
2369 while perm < testGen.args.num_rand_permutations:
2370 # Random choice of type of params we are testing
Jeremy Johnsonb2099702023-04-12 15:59:01 +01002371 if not testGen.args.level8k:
2372 _rnd_param_fn = testGen.rng.choice(
2373 (
2374 get_rand_params,
2375 get_upscale_downscale_params,
2376 get_aspect_ratio_resize_params,
2377 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002378 )
Jeremy Johnsonb2099702023-04-12 15:59:01 +01002379 scale, offset, border = _rnd_param_fn()
2380 else:
2381 scale, offset, border = get_level_8k_params()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002382
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002383 # Expand params for bounds-checking
2384 (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
2385 (offset_y, offset_x) = offset
2386 (border_y, border_x) = border
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002387
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002388 # Make sure output dimensions OH and OW are integers
2389 partial_output_y = (
2390 (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
2391 )
2392 partial_output_x = (
2393 (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
2394 )
2395 if error_name == ErrorIf.ResizeOutputShapeNonInteger:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01002396 # Look for non-integer test
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002397 if (
2398 partial_output_y % scale_y_d == 0
2399 and partial_output_x % scale_x_d == 0
2400 ):
2401 # Skip this test as it doesn't produce NonInteger output
Jeremy Johnsonb2099702023-04-12 15:59:01 +01002402 if perm > 0:
2403 perm += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002404 continue
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002405 else:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01002406 # Alter the scaling factors to make the output integer
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002407 while partial_output_y % scale_y_d != 0:
2408 scale_y_d -= 1
2409 while partial_output_x % scale_x_d != 0:
2410 scale_x_d -= 1
Jeremy Johnsonb2099702023-04-12 15:59:01 +01002411 # Make sure we are still within max scaling
2412 if (
2413 scale_y_n / scale_y_d
2414 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
2415 scale_x_n / scale_x_d
2416 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
2417 # Skip the test as it is using too large a scaling factor
2418 if perm > 0:
2419 perm += 1
2420 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002421
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002422 output_y = partial_output_y // scale_y_d + 1
2423 output_x = partial_output_x // scale_x_d + 1
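                    # e.g. ifm H=16 with scale 2/1, offset 0 and border 0:
                    # partial_output_y = 15 * 2 = 30, so output_y = 30 // 1 + 1 = 31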
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002424
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002425 if (
2426 output_y >= testGen.args.max_resize_output_dim
2427 or output_x >= testGen.args.max_resize_output_dim
2428 ) and error_name is None:
2429 # Skip positive test if output dim will be too high
2430 # Avoid high test latency and OOM issues
Jeremy Johnsonb2099702023-04-12 15:59:01 +01002431 if not testGen.args.level8k or perm > 0:
2432 perm += 1
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002433 continue
2434
2435 if (
2436 output_y <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01002437 or output_y >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002438 or output_x <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01002439 or output_x >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002440 ):
2441 # Output dimensions out of scope
2442 if error_name is not None and perm > 0:
2443 # As long as we have one ERROR_IF test, don't worry
2444 # about creating all the other permutations
2445 perm += 1
2446 continue
2447
2448 if error_name == ErrorIf.ResizeOutputShapeMismatch and (
2449 (
Jeremy Johnson1271c442023-09-05 11:39:26 +01002450 output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002451 and output_y - scale_y_d < 1
2452 )
2453 or (
Jeremy Johnson1271c442023-09-05 11:39:26 +01002454 output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002455 and output_x - scale_x_d < 1
2456 )
2457 ):
2458 # Can't create a negative test with these params as it
2459 # will create invalid output size
2460 if perm > 0:
2461 perm += 1
2462 continue
2463
2464 scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
2465 offset = [offset_y, offset_x]
2466 border = [border_y, border_x]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002467
2468 # Common for all data types
2469 if error_name is not None:
2470 (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002471 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002472 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002473 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002474 outputDTypeNew,
2475 ) = TosaErrorIfArgGen.eiResizeErrorIf(
2476 testGen,
2477 error_name,
2478 mode,
2479 dtype,
2480 shapeList,
2481 outputDType,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002482 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002483 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002484 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002485 )
2486 else:
2487 outputDTypeNew = outputDType
2488
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002489 arg_to_append = (
2490 arg_str.format(
2491 "N" if mode == ResizeMode.NEAREST else "B",
2492 testGen.typeStr(outputDTypeNew),
2493 scale[0],
2494 scale[1],
2495 scale[2],
2496 scale[3],
2497 offset[0],
2498 offset[1],
2499 border[0],
2500 border[1],
2501 ),
2502 [
2503 mode,
2504 scale,
2505 offset,
2506 border,
2507 dtype,
2508 outputDTypeNew,
2509 ],
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002510 )
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002511 if arg_to_append in arg_list:
2512 # Skip already generated test params
2513 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002514
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002515 # Valid permutation
2516 perm += 1
2517 arg_list.append(arg_to_append)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002518 return arg_list
2519
2520 @staticmethod
2521 def agTable(testGen, opName, shapeList, dtype, error_name=None):
2522 arg_list = []
2523
2524 if dtype == DType.INT8:
2525 table = np.int32(
2526 testGen.rng.integers(low=-128, high=128, size=[256])
2527 ).tolist()
2528 else: # INT16
2529 table = np.int32(
2530 testGen.rng.integers(low=-32768, high=32768, size=[513])
2531 ).tolist()
        # Make sure all slopes are within the REQUIRE'd 16-bit int min/max range
2533 for idx in range(len(table) - 1):
2534 slope = table[idx + 1] - table[idx]
2535 # Alter the next table entry to force the slope to be ok
2536 if slope > 32767:
2537 table[idx + 1] -= slope - 32767
2538 if slope < -32768:
2539 table[idx + 1] -= slope + 32768
2540 slope = table[idx + 1] - table[idx]
2541 assert slope <= 32767 and slope >= -32768
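        # e.g. consecutive entries (-30000, 10000) have slope 40000; the
        # second entry is reduced by 40000 - 32767 = 7233, capping the slope
        # at exactly 32767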
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002542 arg_list.append(
2543 (
2544 "",
2545 [table],
2546 )
2547 )
2548 return arg_list
2549
    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
        # CondIf generates the condition values here; they are converted to
        # tensors in the build function, along with the then and else blocks
2554 arg_list = []
2555
2556 for c in [False, True]:
2557 arg_list.append(("cond{}".format(int(c)), [c]))
2558
2559 return arg_list
2560
    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
2562 # While loop: 0 iterations, 1, more than 1
2563 arg_list = []
2564
        for iterations in [0, 1, 4]:
            arg_list.append(("iter{}".format(iterations), [iterations]))
2567
2568 return arg_list