blob: 914760542c1917fde7a30b66a86bdff28a4ebc19 [file] [log] [blame]
Luke Hutton261b7b62023-01-10 14:50:31 +00001# Copyright (c) 2021-2023, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
4import math
James Ward8b390432022-08-12 20:48:56 +01005import warnings
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01006
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
16# DTypeNames, DType, Op and ResizeMode are convenience variables to the
17# flatc-generated types that should be enums, but aren't
18
19
20class TosaQuantGen:
21 """QuantizedInfo random generator helper functions.
22
23 Specify with 'qgen': in the operator defintion.
24 """
25
26 def __init__(self):
27 pass
28
29 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +000030 def getZeroPoint(testGen, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010031
32 if dtype == DType.INT8:
Jeremy Johnson00423432022-09-12 17:27:37 +010033 if testGen.args.zeropoint is not None:
34 return min(127, max(-128, testGen.args.zeropoint))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010035 return testGen.randInt(-128, 128)
36 elif dtype == DType.UINT8:
Jeremy Johnson00423432022-09-12 17:27:37 +010037 if testGen.args.zeropoint is not None:
38 return min(255, max(0, testGen.args.zeropoint))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010039 return testGen.randInt(0, 256)
40 elif error_name in [
41 ErrorIf.InputZeroPointNotZero,
42 ErrorIf.WeightZeroPointNotZero,
43 ErrorIf.OutputZeroPointNotZero,
44 ]:
45 zero_point = testGen.randInt(-128, 128)
46 if zero_point == 0:
47 zero_point = 1
48 return zero_point
49 return 0
50
51 @staticmethod
52 def qgUnary(testGen, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010053 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000054 qinfo = [
55 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
56 TosaQuantGen.getZeroPoint(testGen, dtype),
57 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010058 elif error_name == ErrorIf.OutputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000059 qinfo = [
60 TosaQuantGen.getZeroPoint(testGen, dtype),
61 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
62 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010063 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000064 qinfo = [
65 TosaQuantGen.getZeroPoint(testGen, dtype),
66 TosaQuantGen.getZeroPoint(testGen, dtype),
67 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010068 return qinfo
69
70 @staticmethod
71 def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010072 if isinstance(dtype_or_dtypeList, list):
73 # a list of [input, weights, accumulator] dtypes
74 dtypeList = dtype_or_dtypeList
75 else:
76 # an int, [input, weights, accumulator] dtypes are the same
77 dtypeList = [dtype_or_dtypeList] * 3
78
79 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000080 qinfo = [
81 TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
82 TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
83 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010084 elif error_name == ErrorIf.WeightZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000085 qinfo = [
86 TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
87 TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
88 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010089 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000090 qinfo = [
91 TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
92 TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
93 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010094 return qinfo
95
96 @staticmethod
97 def qgMatmul(testGen, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010098 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000099 qinfo = [
100 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
101 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
102 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100103 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000104 qinfo = [
105 TosaQuantGen.getZeroPoint(testGen, dtype),
106 TosaQuantGen.getZeroPoint(testGen, dtype),
107 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100108 return qinfo
109
110 @staticmethod
111 def computeMultiplierAndShift(scaleFp, scale32):
112 # Derived from computeMultiplierAndShiftTosaScale32
113 # Provide a floating-point scaling factor and the scale32 parameter
114 # to compute the multiplier and shift
115
116 if scale32:
117 scaleBits = 31
118 else:
119 scaleBits = 15
120
121 m, shift = math.frexp(scaleFp)
122
123 if scaleFp < 0.0:
124 m = -m
125
126 multiplier = round(m * (1 << scaleBits))
127 assert multiplier <= (1 << scaleBits)
128
129 if multiplier == (1 << scaleBits):
130 multiplier = multiplier // 2
131 shift = shift + 1
132
133 shift = (-shift) + scaleBits
134 # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
135 # scaleFp, scaleBits, m, multiplier, shift))
136
137 # Adjust multiplier such that shift is in allowed value range.
138 if shift == 0:
139 multiplier = multiplier // 4
140 shift = shift + 2
141 elif shift == 1:
142 multiplier = multiplier // 2
143 shift = shift + 1
144 elif shift == 63:
145 multiplier = multiplier * 2
146 shift = shift - 1
147
148 assert multiplier <= (1 << scaleBits)
149 assert shift >= 2 and shift <= 62
150
151 return multiplier, shift
152
153
154class TosaTensorGen:
155 """Tensor generators create a shape list for the placeholder and const tensor
156 data operands for the operator.
157
158 The actual random data is generated separately for each test.
159 """
160
161 def __init__(self):
162 pass
163
164 @staticmethod
165 def tgBasic(testGen, opName, rank, error_name=None):
166 pl, const = opName["operands"]
167 shape = testGen.makeShape(rank)
168
169 # Constrict the overall size of the shape when creating ERROR_IF tests
170 if error_name:
171 shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
172
173 shape_list = []
174 for i in range(pl + const):
175 shape_list.append(shape.copy())
176
Luke Huttona4e48ca2023-02-22 11:53:48 +0000177 # Generates an input rank mismatch for operators with more than one input
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100178 if error_name == ErrorIf.RankMismatch:
179 if rank == 1 and i != 1:
180 shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
181 elif i != 1:
182 shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))
183
184 return shape_list
185
186 @staticmethod
187 def tgNHWC(testGen, opName, rank, error_name=None):
188 pl, const = opName["operands"]
189
190 if error_name != ErrorIf.WrongRank:
191 assert rank == 4
192
193 shape = testGen.makeShape(rank)
James Ward30124a82023-02-02 14:56:33 +0000194 shape = testGen.constrictBatchSize(shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100195
196 # Constrict the overall size of the shape when creating ERROR_IF tests
197 if error_name and error_name != ErrorIf.MaxDimExceeded:
198 shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
199
200 shape_list = []
201 for i in range(pl + const):
202 shape_list.append(shape.copy())
203
204 return shape_list
205
206 @staticmethod
207 def tgScatter(testGen, opName, rank, error_name=None):
208 pl, const = opName["operands"]
209
210 assert pl == 2
211 assert const == 0
212 if error_name != ErrorIf.WrongRank:
213 assert rank == 3
214
215 values_in_shape = testGen.makeShape(rank)
216
217 # ignore max batch size if target shape is set
218 if testGen.args.max_batch_size and not testGen.args.target_shapes:
James Ward30124a82023-02-02 14:56:33 +0000219 values_in_shape[0] = min(values_in_shape[0], testGen.args.max_batch_size)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100220
221 W = testGen.randInt(
222 testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
223 )
224 # Constrict W if one dimension is too large to keep tensor size reasonable
225 if max(values_in_shape) > 5000:
226 W = testGen.randInt(0, 16)
227
228 input_shape = [values_in_shape[0], W, values_in_shape[2]]
229
230 shape_list = []
231 shape_list.append(values_in_shape.copy())
232 shape_list.append(input_shape.copy())
233
234 return shape_list
235
236 @staticmethod
237 def tgBroadcastFuzz(testGen, op, rank, error_name=None):
238 shape = testGen.makeShape(rank)
239
240 pl, const = op["operands"]
241
242 shape_list = []
243
244 # Choose one of the inputs to broadcast
245 # Note: Simplifies OutputShaper code if we don't change first shape for errors
246 bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
Jerry Ge135c9552023-05-23 20:59:32 +0000247 fuzz_idx = testGen.randInt(0, rank)
248
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100249 for i in range(pl + const):
250 shape_bcast = shape.copy()
251
Jerry Ge135c9552023-05-23 20:59:32 +0000252 # To test broadcasting, the chosen fuzz index dimension should not be 1
253 if shape_bcast[fuzz_idx] == 1:
254 shape_bcast[fuzz_idx] += 1
255
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100256 # If the chosen input, pick a random index to broadcast
257 if i == bcast_idx:
Jerry Ge135c9552023-05-23 20:59:32 +0000258 if error_name == ErrorIf.RankMismatch:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100259 # Add one rank to the shape (or more for rank of 1)
260 extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
261 shape_bcast = np.concatenate(
262 (shape_bcast, testGen.makeShape(extra_ranks))
263 )
264 if rank != 1:
265 # Either keep the extra rank, or remove it
266 new_len = testGen.rng.choice([-2, len(shape_bcast)])
267 shape_bcast = shape_bcast[:new_len]
Jerry Ge135c9552023-05-23 20:59:32 +0000268 elif error_name == ErrorIf.BroadcastShapesMismatch:
269 shape_bcast[fuzz_idx] += 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100270 else:
271 shape_bcast[fuzz_idx] = 1
272
273 shape_list.append(shape_bcast)
274
275 return shape_list
276
277 @staticmethod
278 def tgConv2D(testGen, op, rank, error_name=None):
279 pl, const = op["operands"]
280
281 if error_name != ErrorIf.WrongRank:
282 assert rank == 4
283
284 # IFM dimensions are NHWC
285 ifm_shape = testGen.makeShape(rank)
James Ward30124a82023-02-02 14:56:33 +0000286 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100287
288 # Constrict the overall size of the shape when creating ERROR_IF tests
289 if error_name:
290 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
291 ifm_shape, max_dim=24, max_items=10000
292 )
293
294 # Get the filter height/width from the operator parameters
295 filter_hw = op["filter"]
296
297 # Generate a random OFM depth
James Ward30124a82023-02-02 14:56:33 +0000298 ofm_depth = testGen.makeDimension()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100299
300 # The filter dimensions are OHWI
301 filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
302
303 # The bias is OC
304 bias_shape = np.asarray([ofm_depth])
305
306 return [ifm_shape, filter_shape, bias_shape]
307
308 @staticmethod
309 def tgConv3D(testGen, op, rank, error_name=None):
310 pl, const = op["operands"]
311
312 if error_name != ErrorIf.WrongRank:
313 assert rank == 5
314
315 # IFM dimensions are NDHWC
316 ifm_shape = testGen.makeShape(rank)
James Ward30124a82023-02-02 14:56:33 +0000317 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100318
319 # Constrict the overall size of the shape when creating ERROR_IF tests
320 if error_name:
321 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
322 ifm_shape, max_dim=24, max_items=10000
323 )
324
325 # Get the filter depth/height/width from the operator parameters
326 filter_dhw = op["filter"]
327
328 # Generate a random OFM channel
James Ward30124a82023-02-02 14:56:33 +0000329 ofm_channel = testGen.makeDimension()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100330
331 # The filter dimensions are ODHWI
332 filter_shape = np.asarray(
333 [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
334 )
335
336 # The bias is OC
337 bias_shape = np.asarray([ofm_channel])
338
339 return [ifm_shape, filter_shape, bias_shape]
340
341 @staticmethod
342 def tgTransposeConv2D(testGen, op, rank, error_name=None):
343 pl, const = op["operands"]
344
345 if error_name != ErrorIf.WrongRank:
346 assert rank == 4
347
348 # IFM dimensions are NHWC
349 ifm_shape = testGen.makeShape(rank)
James Ward30124a82023-02-02 14:56:33 +0000350 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100351
352 # Constrict the overall size of the shape when creating ERROR_IF tests
353 if error_name:
354 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
355 ifm_shape, max_dim=24, max_items=10000
356 )
357
358 # Get the filter height/width from the operator parameters
359 filter_hw = op["filter"]
360
361 # Generate a random OFM depth
James Ward30124a82023-02-02 14:56:33 +0000362 ofm_depth = testGen.makeDimension()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100363
364 # The filter dimensions are OHWI
365 filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
366
367 # The bias is OC
368 bias_shape = np.asarray([ofm_depth])
369
370 return [ifm_shape, filter_shape, bias_shape]
371
372 @staticmethod
373 def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
374 pl, const = op["operands"]
375
376 if error_name != ErrorIf.WrongRank:
377 assert rank == 4
378 assert pl == 1 and const == 2
379
380 # IFM dimensions are NHWC
381 ifm_shape = testGen.makeShape(rank)
James Ward30124a82023-02-02 14:56:33 +0000382 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100383
384 # Constrict the overall size of the shape when creating ERROR_IF tests
385 if error_name:
386 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
387 ifm_shape, max_dim=24, max_items=10000
388 )
389
390 # Get the filter height/width from the operator parameters
391 # Filter is KH, HW, C, M
392 filter_hw = op["filter"]
393
394 # Generate a random OFM depth, but don't let it get too big because
395 # the output depth is M * C
396 filter_m = (
James Ward30124a82023-02-02 14:56:33 +0000397 testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100398 ) + 1
399
400 # The filter dimensions are HWCM
401 filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])
402
403 # The bias is M * C
404 bias_shape = np.asarray([ifm_shape[3] * filter_m])
405
406 return [ifm_shape, filter_shape, bias_shape]
407
408 @staticmethod
Luke Hutton57287132023-02-06 14:54:18 +0000409 def tgFFT2d(testGen, op, rank, error_name=None):
410 pl, const = op["operands"]
411
412 if error_name != ErrorIf.WrongRank:
413 assert rank == 3
414 assert pl == 2 and const == 0
415
416 # IFM dimensions are NHW
417 ifm_shape = testGen.makeShape(rank)
418
419 # Select nearest lower power of two from input height and width
420 ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
421 ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))
422
423 # Constrict the overall size of the shape when creating ERROR_IF tests
424 if error_name:
425 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)
426
427 # Generate an invalid kernel that is not a power of two
428 if error_name == ErrorIf.KernelNotPowerOfTwo:
429 inc_h = 2 if ifm_shape[1] == 1 else 1
430 inc_w = 2 if ifm_shape[2] == 1 else 1
431 inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
432 selected_inc = testGen.rng.choice(inc_choices)
433 ifm_shape[1] += selected_inc[0]
434 ifm_shape[2] += selected_inc[1]
435
436 ifm_shape = testGen.constrictBatchSize(ifm_shape)
437
438 ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
439 if error_name == ErrorIf.FFTInputShapeMismatch:
440 modify_shape = testGen.rng.choice([0, 1])
441 # Only modify kernel (H, W)
442 modify_dim = testGen.rng.choice([1, 2])
443 ifm_shapes[modify_shape][modify_dim] *= 2
444
445 return [ifm_shapes[0], ifm_shapes[1]]
446
447 @staticmethod
Luke Hutton261b7b62023-01-10 14:50:31 +0000448 def tgRFFT2d(testGen, op, rank, error_name=None):
449 pl, const = op["operands"]
450
451 if error_name != ErrorIf.WrongRank:
452 assert rank == 3
453 assert pl == 1 and const == 0
454
455 # IFM dimensions are NHW
456 ifm_shape = testGen.makeShape(rank)
457
458 # Select nearest lower power of two from input height and width
459 ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
460 ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))
461
462 # Constrict the overall size of the shape when creating ERROR_IF tests
463 if error_name:
464 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)
465
466 # Generate an invalid kernel that is not a power of two
467 if error_name == ErrorIf.KernelNotPowerOfTwo:
468 # We must increment by 2 if current size is 1
469 inc_h = 2 if ifm_shape[1] == 1 else 1
470 inc_w = 2 if ifm_shape[2] == 1 else 1
471 inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
472 selected_inc = testGen.rng.choice(inc_choices)
473 ifm_shape[1] += selected_inc[0]
474 ifm_shape[2] += selected_inc[1]
475
James Ward30124a82023-02-02 14:56:33 +0000476 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Luke Hutton261b7b62023-01-10 14:50:31 +0000477
478 return [ifm_shape]
479
480 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100481 def tgFullyConnected(testGen, op, rank, error_name=None):
482 pl, const = op["operands"]
483
484 if error_name != ErrorIf.WrongRank:
485 assert rank == 2
486
487 input_shape = testGen.makeShape(rank)
488
489 # Constrict the overall size of the shape when creating ERROR_IF tests
490 if error_name:
491 input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)
492
493 filter_oc = testGen.rng.integers(
494 low=testGen.args.tensor_shape_range[0],
495 high=testGen.args.tensor_shape_range[1],
496 size=1,
497 )[0]
498 filter_shape = np.asarray([filter_oc, input_shape[1]])
499
500 bias_shape = np.asarray([filter_oc])
501
502 return [input_shape, filter_shape, bias_shape]
503
504 @staticmethod
505 def tgMatmul(testGen, op, rank, error_name=None):
506 pl, const = op["operands"]
507
508 if error_name != ErrorIf.WrongRank:
509 assert rank == 3
510 assert pl == 2 and const == 0
511
512 a_shape = testGen.makeShape(rank)
513
514 # Constrict the overall size of the shape when creating ERROR_IF tests
515 if error_name:
516 a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)
517
518 # Get a random number for b_oc even if target shape is defined
519 b_oc = np.int32(
520 testGen.rng.integers(
521 low=testGen.args.tensor_shape_range[0],
522 high=testGen.args.tensor_shape_range[1],
523 size=1,
524 )
525 )[0]
526 # If N or H is large let b_oc be 1 to reduce output tensor size
527 if max(a_shape) > 1000:
528 b_oc = 1
529
530 b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
531 return [a_shape, b_shape]
532
533 @staticmethod
534 def tgConcat(testGen, opName, rank, error_name=None):
535 pl, const = opName["operands"]
536 shape = testGen.makeShape(rank)
537
538 # Create extra tensors to concat.
539 # Take into account value of pl when getting maximum number of concats
540 num_tensors = testGen.randInt(0, 4)
541 shape_list = []
542 for i in range(pl + const + num_tensors):
543 if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
544 remove = testGen.rng.choice([True, False])
545 wrongShape = shape.copy()
546
547 if remove and len(shape) > 1:
548 wrongShape = wrongShape[1:]
549 else:
550 wrongShape = list(wrongShape)
551 wrongShape.append(testGen.rng.integers(1, 10))
552
553 shape_list.append(wrongShape)
554 else:
555 shape_list.append(shape.copy())
556
557 return shape_list
558
559 @staticmethod
560 def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
561 if error_name in [
562 ErrorIf.AxisSmallerZero,
563 ErrorIf.AxisLargerRank,
564 ErrorIf.ConcatInputRankMismatch,
565 ]:
566 return shapeList
567
568 # Split concat shape along axis to allow for multiple const inputs
569 # without making too many large tensors
570 if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
571 # If axis can't be split we still need to invalidate other dimensions
572 if error_name == ErrorIf.ConcatInputDimMismatch:
573 for shape in shapeList[1:]:
574 # Negative test shapeLists are created individually for each test,
575 # so no need to copy the shape before altering it.
576 shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
577 return shapeList
578
579 # Create copy of shape we are going to split (so we don't alter shapeList)
580 shape = shapeList[0].copy()
581 # Add original shape as first input
582 new_shapeList = [shape.copy()]
583 length_on_axis = shape[axis]
584 remaining_length = length_on_axis
585 for i in range(len(shapeList) - 2):
586 # Calculate split on axis and remaining value
587 split_shape_val = int(shape[axis] / 2)
588 remaining_length = remaining_length - split_shape_val
589
590 # Append new shape, and set remaining shape
591 shape[axis] = split_shape_val
592 new_shapeList.append(shape.copy())
593
594 # invalidate dimensions
595 if error_name == ErrorIf.ConcatInputDimMismatch:
596 shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
597 else:
598 shape[axis] = remaining_length
599
600 if i == len(shapeList) - 3:
601 new_shapeList.append(shape.copy())
602
603 return new_shapeList
604
605
606class TosaTensorValuesGen:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100607 """Tensor Value generators create the random data for each tensor in each test."""
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100608
609 def __init__(self):
610 pass
611
    class TVGInfo:
        """Enhanced tensor values information including data gen dict."""

        def __init__(self, tensorList, dataGenDict):
            # Serializer tensors created for the test
            self.tensorList = tensorList
            # Data generator meta-data dictionary (None when no data gen used)
            self.dataGenDict = dataGenDict
618
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100619 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000620 def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100621 pCount, cCount = op["operands"]
622
623 tens = []
624 tens.extend(
625 testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
626 )
627 tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
628
629 return tens
630
    # Default high value for random numbers
    # These are the largest finite values for each format:
    # (2 - 2**-mantissa_bits) * 2**max_exponent
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
    }
637
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100638 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000639 def _get_data_range(testGen, dtype, highValueLookup):
640 if dtype in highValueLookup:
641 data_range = testGen.getDTypeRange(dtype, high_inclusive=True)
642 high_val = highValueLookup[dtype]
643 # Set the values to something that won't produce infinity whilst
644 # respecting the default ranges if less than the high value
645 return [
646 max(-high_val, data_range[0]),
647 min(high_val, data_range[1]),
648 ]
649 return None
650
651 @staticmethod
Jeremy Johnson1271c442023-09-05 11:39:26 +0100652 def tvgLazyGenDefault(
653 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
654 ):
655 # Variable inputs versus constants
656 pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000657 tens_ser_list = []
Jeremy Johnson1271c442023-09-05 11:39:26 +0100658
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100659 if (
660 error_name is not None
661 or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100662 or "data_gen" not in testGen.TOSA_OP_LIST[opName]
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100663 ):
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000664 # Fall back to internal data gen when dealing with unsupported types or ops
665 data_range = argsDict["data_range"] if "data_range" in argsDict else None
666 for idx, info in enumerate(zip(shapeList, dtypeList)):
667 shape, dtype = info
668 # Ignore lazy data gen option and create data array using any range limits
669 arr = testGen.getRandTensor(shape, dtype, data_range)
670 if idx < pCount:
671 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
672 else:
673 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100674
Jeremy Johnson1271c442023-09-05 11:39:26 +0100675 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
676
677 # Create data generator meta-data
678 dg_type = argsDict["dg_type"]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100679 tens_data = {
680 "version": "0.1",
681 "tensors": {},
682 }
683 dg_tens_meta = tens_data["tensors"]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100684 for idx, shape in enumerate(shapeList):
685
686 tens_meta = {}
687 tens_meta["generator"] = gtu.DataGenType(dg_type).name
688 tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
689 tens_meta["shape"] = [int(i) for i in shape]
690 tens_meta["input_pos"] = idx
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100691 tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()
Jeremy Johnson1271c442023-09-05 11:39:26 +0100692
693 if idx < pCount:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100694 tens_meta["input_type"] = "VARIABLE"
Jeremy Johnson1271c442023-09-05 11:39:26 +0100695 else:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100696 tens_meta["input_type"] = "CONSTANT"
Jeremy Johnson1271c442023-09-05 11:39:26 +0100697
698 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
699 info = {}
700 # TODO - generate seed for this generator based on test
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100701 info["rng_seed"] = 42
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100702 if "data_range" in argsDict:
703 data_range = argsDict["data_range"]
704 else:
705 data_range = testGen.getDTypeRange(
706 dtypeList[idx], high_inclusive=True
707 )
708 info["range"] = [str(v) for v in data_range]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100709 tens_meta["pseudo_random_info"] = info
710 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
711 info = {}
712 info["s"] = argsDict["s"]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100713 info["ks"] = int(argsDict["ks"])
714 if "acc_type" in argsDict:
715 # Convert type number into JSON name
716 info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
717 "json"
718 ]
719 if "kernel" in argsDict:
720 info["kernel"] = [int(k) for k in argsDict["kernel"]]
721 if "axis" in argsDict:
722 info["axis"] = int(argsDict["axis"])
Jeremy Johnson1271c442023-09-05 11:39:26 +0100723 tens_meta["dot_product_info"] = info
724 else:
725 # TODO - other data gen type
726 assert False, "TODO: support other data gen types"
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100727
728 # Using the finished generate config meta data - generate the data if
729 # needed and assign a tensor name from the serializer
730
731 # Need to generate data when not lazy or for the bias tensor as we need
732 # to work out if the bias data is non-zero for compliance
733 if not testGen.args.lazy_data_gen or (
734 idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
735 ):
736 # Give this tensor a temporary name until we get one from the serializer
737 temp_name = f"placeholder_{idx}"
738 dg_tens_meta[temp_name] = tens_meta
739 # Create data now using the temporary name to access meta details
740 data = testGen.dgl.get_tensor_data(temp_name, tens_data)
741 # Remove the item as we will give it the correct name later
742 del dg_tens_meta[temp_name]
743
744 if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
745 # The KS value used by compliance verification is altered when the
746 # bias data is non-zero
747 if max(abs(data)) > 0.0:
748 argsDict["ksb"] = argsDict["ks"] + 1
749
750 if testGen.args.lazy_data_gen:
751 data = None
752
753 if tens_meta["input_type"] == "VARIABLE":
754 tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
755 else:
756 tens = testGen.ser.addConst(shape, dtypeList[idx], data)
757
758 tens_ser_list.append(tens)
759 # Add the meta data to the list using the serializer tensor name
Jeremy Johnson1271c442023-09-05 11:39:26 +0100760 dg_tens_meta[tens.name] = tens_meta
761
Jeremy Johnson1271c442023-09-05 11:39:26 +0100762 return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
763
764 @staticmethod
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000765 def tvgNegate(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100766 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000767 # Integer test
768 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100769 pCount, cCount = op["operands"]
770 assert (
771 pCount == 1 and cCount == 0
772 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100773 # Must create tensors with values within accumulator (int32) negatable
774 # range
775 max_val = (1 << 31) - 1
776 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100777 arr = np.int32(
778 testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
779 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000780 tens_ser_list = []
781 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100782 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
783 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000784 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100785 else:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000786 # ERROR_IF or floating point test
787 return TosaTensorValuesGen.tvgLazyGenDefault(
788 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100789 )
790
    # Set the data range to half the largest value
    # Halving ensures the sum/difference of two inputs cannot exceed the
    # format's largest finite value (see TVG_FLOAT_HIGH_VALUE)
    TVG_FLOAT_HIGH_VALUE_ADDSUB = {
        DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
        DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
        DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
    }
797
    @staticmethod
    def tvgAddSub(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create value tensors for ADD / SUB.

        INT32 positive tests: generate two random tensors and then clip the
        second operand so the int32 add/sub result cannot saturate.
        All other cases (ERROR_IF or floating point) defer to the lazy default
        generator with a reduced data range (TVG_FLOAT_HIGH_VALUE_ADDSUB).

        Returns a TosaTensorValuesGen.TVGInfo.
        """
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Make sure the integer operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            tens_ser_list = []
            add = op["op"] == Op.ADD
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            # Compute the result at int64 precision so overflow is observable
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                # (b may be broadcast, so collapse any axis where dims differ)
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            data_range = TosaTensorValuesGen._get_data_range(
                testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
            )
            if data_range:
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
874
875 @staticmethod
876 def tvgCondIfWhileLoop(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000877 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100878 ):
879 if dtypeList[0] in (
880 DType.INT32,
881 DType.INT16,
882 DType.INT8,
883 ):
884 # Limit input tensors with cond_if_binary or while_loop to stop
885 # saturation of add/sub ops with int32 and keep all logical shift
886 # values between 0 to 31 for int16 or int8
887 pCount, cCount = op["operands"]
888 pRemain = pCount
889 placeholders = []
890 for idx, shape in enumerate(shapeList[:]):
891 if dtypeList[0] == DType.INT32:
892 arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
893 else:
894 arr = np.int32(
895 testGen.rng.integers(low=0, high=32, size=shapeList[idx])
896 )
897 if pRemain > 0:
898 placeholders.append(
899 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
900 )
901 pRemain -= 1
902 else:
903 placeholders.append(
904 testGen.ser.addConst(shape, dtypeList[idx], arr)
905 )
906
907 return placeholders
908 else:
909 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000910 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100911 )
912
913 @staticmethod
914 def tvgArithmeticRightShift(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000915 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100916 ):
917 pCount, cCount = op["operands"]
918 # Force value of operand[1] to be within [0, num_bits]
919 assert (
920 pCount == 2 and cCount == 0
921 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
922
923 placeholders = []
924 for idx, shape in enumerate(shapeList[:]):
925 if idx == 1:
926 if dtypeList[idx] == DType.INT8:
927 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
928 elif dtypeList[idx] == DType.INT16:
929 arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
930 elif dtypeList[idx] == DType.INT32:
931 arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
932 elif error_name == ErrorIf.WrongInputType:
933 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
934 else:
935 raise Exception("OpArithmeticRightShift: invalid input dtype")
936 else:
937 arr = testGen.getRandTensor(shape, dtypeList[idx])
938 placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
939
940 return placeholders
941
942 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000943 def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100944 # Set datatype of condition tensor to boolean
945 dtypeList[0] = DType.BOOL
946
947 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000948 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100949 )
950
951 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000952 def tvgIntDiv(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100953 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000954 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100955 pCount, cCount = op["operands"]
956 assert (
957 pCount == 2 and cCount == 0
958 ), "Op.INTDIV must have 2 placeholders, 0 consts"
959
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000960 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100961
962 # Two invalid cases for Op.INTDIV:
963 # 1. divisor == 0
964 # 2. dividend == -(1<<31) and divisor == -1
965 while True:
966 dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
967 divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
968
969 if (divisor_arr == 0).any():
970 continue
971
972 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
973 continue
974
975 break
976
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000977 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100978 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
979 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000980 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100981 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
982 )
983
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000984 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100985 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000986 return TosaTensorValuesGen.tvgLazyGenDefault(
987 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100988 )
989
    # Set the data range to the square root of the largest value
    # (so the product of two in-range operands stays within the dtype's
    # representable range)
    TVG_FLOAT_HIGH_VALUE_MUL = {
        DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
996
    @staticmethod
    def tvgMul(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create value tensors for MUL.

        Integer tests: random inputs are repeatedly halved until the shifted
        product of every element pair fits in int32.
        ERROR_IF or floating point tests: defer to the lazy default generator
        with the data range limited to sqrt(max) so products stay in range.

        Returns a TosaTensorValuesGen.TVGInfo.
        """
        if error_name is not None or dtypeList[0] in (
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ):
            # ERROR_IF or floating point test
            data_range = TosaTensorValuesGen._get_data_range(
                testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
            )
            if data_range:
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
        else:
            # Integer test
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            tens_ser_list = []

            # Make sure multiply result in int32 range
            shift = argsDict["shift"]
            if dtypeList[0] == DType.INT8:
                num_bits = 8
            elif dtypeList[0] == DType.INT16:
                num_bits = 16
            elif dtypeList[0] == DType.INT32:
                num_bits = 32
            elif error_name == ErrorIf.WrongInputType:
                num_bits = 8
            else:
                raise Exception("OpMul: invalid input dtype")

            # NOTE(review): this outer loop regenerates BOTH input tensors on
            # every iteration and only the final iteration's arrays are used
            # below - the enumerate over shapeList looks redundant; confirm
            # before simplifying.
            for idx, shape in enumerate(shapeList[:]):
                low = -(2 ** (num_bits - 1))
                high = (2 ** (num_bits - 1)) - 1

                a_arr = np.int32(
                    testGen.rng.integers(low=low, high=high, size=shapeList[0])
                )
                b_arr = np.int32(
                    testGen.rng.integers(low=low, high=high, size=shapeList[1])
                )

                i = 0
                while True:

                    # Evaluate the (optionally rounded and shifted) product at
                    # int64 precision so int32 overflow can be detected
                    a_arr_64 = a_arr.astype(np.int64)
                    b_arr_64 = b_arr.astype(np.int64)

                    if shift > 0:
                        rounding = 1 << (shift - 1)
                        result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                    else:
                        result_arr = a_arr_64 * b_arr_64

                    if (result_arr > -(2**31)).all() and (
                        result_arr <= ((2**31) - 1)
                    ).all():
                        break

                    # Halve both inputs and retry until the result fits
                    i = i + 1
                    a_arr = a_arr // 2
                    b_arr = b_arr // 2

            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001077
1078 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001079 def tvgConcat(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001080 count = len(shapeList) - testGen.args.num_const_inputs_concat
1081 if count < 1:
1082 count = 1
1083 if testGen.args.num_const_inputs_concat == 0:
1084 count = len(shapeList)
1085
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001086 shapeList = TosaTensorGen.tgConcatConstInput(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001087 testGen, shapeList, argsDict["axis"], error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001088 )
1089
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001090 tens_ser_list = []
1091 tens_ser_list.extend(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001092 testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
1093 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001094 tens_ser_list.extend(
1095 testGen.buildConstTensors(shapeList[count:], dtypeList[count:])
1096 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001097
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001098 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001099
1100 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001101 def tvgLogicalShift(
1102 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1103 ):
1104 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001105 pCount, cCount = op["operands"]
1106 assert (
1107 pCount == 2 and cCount == 0
1108 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
1109 values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1110 shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001111 tens_ser_list = []
1112 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001113 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1114 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001115 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001116 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1117 )
1118
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001119 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001120
    @staticmethod
    def tvgEqual(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create value tensors for EQUAL.

        Integer tests: a number of matching values are forced into the two
        random inputs so the result contains some true elements.
        ERROR_IF or floating point tests use the lazy default generator.
        """
        if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
            # Integer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"

            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])

            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        testGen.rng.integers(
                            0, max(shapeList[0][axis], shapeList[1][axis])
                        )
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                # Copy one element of b into a so the two tensors match there
                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
1167
1168 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001169 def tvgReduceSum(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001170 if dtypeList[0] == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001171 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001172 pCount, cCount = op["operands"]
1173 assert (
1174 pCount == 1 and cCount == 0
1175 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
1176 # Limit values so that the sum cannot exceed the range of an int32 during
1177 # summation of any axis
1178 range_val = int((1 << 31) / max(shapeList[0]))
1179 values_arr = np.int32(
1180 testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
1181 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001182 tens_ser_list = []
1183 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001184 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1185 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001186 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001187 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001188 # ERROR_IF or dot product floating point test
1189 return TosaTensorValuesGen.tvgLazyGenDefault(
1190 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001191 )
1192
1193
1194class TosaArgGen:
1195 """Argument generators create exhaustive or random lists of attributes for
1196 operators that take attributes or other parameters.
1197
1198 The return value is a list of (descriptive_name, [arglist]) tuples where
1199 the descriptive_name is appended to the test name and the arglist is expanded
1200 as arguments to the operator build function.
1201 """
1202
1203 def __init__(self):
1204 pass
1205
1206 @staticmethod
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001207 def _add_data_generators(testGen, opName, dtype, arg_list, error_name):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001208 """Add extra tests for each type of data generator for this op."""
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001209 if (
1210 error_name is None
1211 and "data_gen" in testGen.TOSA_OP_LIST[opName]
1212 and gtu.dtypeIsSupportedByCompliance(dtype)
1213 ):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001214 if dtype in [DType.FP16, DType.FP32, DType.BF16]:
1215 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
1216 else:
1217 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
1218 else:
1219 # Error test or No data generator types listed - assume random
1220 dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
1221
1222 # Expand arg list with other data generator types
1223 new_arg_list = []
1224 for dg_type in dataGenTypesList:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001225 for arg_str, args_dict in arg_list:
1226 args_dict["dg_type"] = dg_type
Jeremy Johnson1271c442023-09-05 11:39:26 +01001227 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
1228 # Default test
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001229 new_arg_list.append((arg_str, args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001230
1231 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
1232 # Extra tests for each dot product test set
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001233 dot_products = args_dict["dot_products"]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001234 if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001235 shape_info = (
1236 " ({})".format(testGen.shapeStr(args_dict["shape"]))
1237 if "shape" in args_dict
1238 else ""
1239 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01001240 print(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001241 f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
Jeremy Johnson1271c442023-09-05 11:39:26 +01001242 )
1243 continue
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001244 # KS and acc_type is required by all dot product generators
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001245 assert "ks" in args_dict
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001246 assert "acc_type" in args_dict
Jeremy Johnson1271c442023-09-05 11:39:26 +01001247
1248 for s in testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS:
1249 new_arg_str = f"{arg_str}_s{s}"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001250 new_args_dict = args_dict.copy()
1251 new_args_dict["s"] = s
1252 new_arg_list.append((new_arg_str, new_args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001253
1254 return new_arg_list
1255
1256 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001257 def agNone(testGen, opName, shapeList, dtype, error_name=None):
1258 """A trivial argument generator for operators that don't take any
1259 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001260 arg_list = TosaArgGen._add_data_generators(
1261 testGen,
1262 opName,
1263 dtype,
1264 [("", {})],
1265 error_name,
1266 )
1267 # Return list of tuples: (arg_str, args_dict)
1268 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001269
1270 @staticmethod
1271 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
1272 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001273 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001274 shape = shapeList[0]
1275
1276 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001277 # Set too small axis
1278 axes = [testGen.rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001279 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001280 # Set too large axis
1281 axes = [testGen.rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001282 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001283 # Create tests for each dimension
1284 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001285
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001286 opid = testGen.TOSA_OP_LIST[opName]["op"]
1287
1288 for a in axes:
1289 args_dict = {"axis": int(a)}
1290 if opid == Op.REDUCE_SUM:
1291 args_dict["dot_products"] = gtu.product(shape)
1292 args_dict["shape"] = shape
1293 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1294 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1295
1296 arg_list.append(("axis{}".format(a), args_dict))
1297
1298 arg_list = TosaArgGen._add_data_generators(
1299 testGen,
1300 opName,
1301 dtype,
1302 arg_list,
1303 error_name,
1304 )
1305 # Return list of tuples: (arg_str, args_dict)
1306 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001307
1308 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001309 def _calculate_sparsity(num_tests, sparsity_factor):
1310 sparsity = num_tests // sparsity_factor + 1
1311 # If there are only a small number of tests, just select them all
1312 if sparsity < 13:
1313 sparsity = 1
1314 # To get a variety of parameter combinations sparsity should not be a
1315 # multiple of 2, 3 or 5
1316 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1317 sparsity += 1
1318 return sparsity
1319
    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        """Build stride/pad/dilation argument combinations for convolutions.

        Used by CONV2D, CONV3D and DEPTHWISE_CONV2D. Produces a sparse sweep
        of (stride, pad, dilation) tuples that yield valid (or, for ERROR_IF
        tests, deliberately invalid) output shapes, plus compliance info
        (ks, dot_products) for the data generators.
        """
        # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        # Shape: Batches, (Depth), Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: (OFM channels), (KD), KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Op type checks
        conv3d = opName.startswith("conv3d")
        depthwise = opName.startswith("depthwise")

        # Check the rank
        rank = 5 if conv3d else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits channels
        k_rank = rank - 2
        # Depthwise filters are (KH, KW, ...), others have OFM channels first
        k_pos = 0 if depthwise else 1
        k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
        # compliance size - KS
        k_size = gtu.product(k_shape)
        if not depthwise:
            k_size *= ifm_shape[-1]

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            if error_name == ErrorIf.PadSmallerZero:
                p_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                # Stride must be greater than 1 to force non-integer error
                startStride = (
                    1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
                )
                s_vals = [
                    x for x in range(startStride, testGen.args.max_conv_stride + 1)
                ]
            strides = {x for x in itertools.product(*([s_vals] * k_rank))}
            if error_name == ErrorIf.DilationSmallerOne:
                d_vals = [testGen.rng.choice(range(-5, 1))]
            else:
                d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
            dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[0, bigPadding]] * (k_rank * 2))
                            )
                        }
                    )
                bigStride = 8
                strides.update(
                    {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
                )
                bigDilation = 7
                dilations.update(
                    {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
                )
            # No output size limit in the comprehensive case
            max_dim_size = None

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(paddings) * len(strides) * len(dilations), sparsity_factor
            )
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            dilation_shape = [1] * k_rank
            pad_shape = [0] * k_rank * 2
            if conv3d:
                # Small stride apart from for big kernel (see below) to keep
                # tensor size/calculation small
                stride_shape = [1] * k_rank
                for idx in range(k_rank):
                    pad_offset = idx * 2
                    if k_shape[idx] == bigKernel:
                        # Padding shape needs to account for tensor shape
                        pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                        pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                        # Big stride to reduce output size
                        stride_shape[idx] = bigKernel
                    else:
                        # Account for kernel size
                        pad_shape[pad_offset] = k_shape[idx] - 1
            else:
                # Always have a large stride with extra padding and dilation to keep
                # tensor calculation reasonable
                stride_shape = [bigKernel] * k_rank
                for idx in range(k_rank):
                    # Dilation shape must account for kernel size
                    dilation_shape[idx] = bigKernel // k_shape[idx]
                    # Padding shape needs to accommodate tensor/kernel & dilation
                    pad_offset = idx * 2
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

            strides = {tuple(stride_shape)}
            dilations = {tuple(dilation_shape)}
            paddings = {tuple(pad_shape)}
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a positive
                        # sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
                        and (
                            k_rank < 3
                            or (
                                (ifm_shape[3] - 1 + p[4] + p[5])
                                > d[2] * (k_shape[2] - 1)
                            )
                        )
                    ):
                        remainders = []
                        outputs = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            # Effective extent available for the kernel steps
                            partial = (
                                ifm_shape[index + 1]
                                - 1
                                + p[pad_offset]
                                + p[pad_offset + 1]
                                - (k_shape[index] - 1) * d[index]
                            )
                            remainders.append(partial % s[index])
                            outputs.append((partial // s[index]) + 1)

                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            if (
                                max_dim_size is not None
                                and max(outputs) >= max_dim_size
                            ):
                                # Test will consume too much memory - skip it
                                continue

                            # Compliance - number of dot product calculations
                            if depthwise:
                                # TODO - add support
                                dots = 0
                            else:
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, filter_shape[0])
                                )
                            args_dict = {
                                "acc_type": accum_dtype,
                                "stride": s,
                                "pad": p,
                                "dilation": d,
                                "kernel": k_shape,
                                "ks": k_size,
                                "dot_products": dots,
                                "shape": ifm_shape,
                            }

                            # Support for larger values than 9 needs different delimiter
                            delim = "" if max(s + p + d) <= 9 else "x"
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        delim.join([str(x) for x in s]),
                                        delim.join([str(x) for x in p]),
                                        delim.join([str(x) for x in d]),
                                    ),
                                    args_dict,
                                )
                            )
                    n += 1

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtypes[0],
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
1543
1544 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01001545 def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
1546
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001547 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001548 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01001549
1550 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson1271c442023-09-05 11:39:26 +01001551 accum_dtype = gtu.get_wrong_output_type(opName, testGen.rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01001552 elif error_name == ErrorIf.WrongInputType:
1553 # Pick some potentially correct output dtype if input type is incorrect
1554 accum_dtype = DType.INT32
1555 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +01001556 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01001557
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001558 # Set up compliance info
1559 args_dict = {
1560 "acc_type": accum_dtype,
1561 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
1562 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
1563 "shape": shapeList[0],
1564 }
1565
1566 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
1567
1568 arg_list = TosaArgGen._add_data_generators(
1569 testGen,
1570 opName,
1571 input_dtype,
1572 arg_list,
1573 error_name,
1574 )
1575 # Return list of tuples: (arg_str, args_dict)
1576 return arg_list
James Ward8b390432022-08-12 20:48:56 +01001577
1578 @staticmethod
1579 def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
1580 # Get valid accumulate type(s)
1581 if dtype == DType.INT8:
1582 accum_dtypes = [DType.INT32]
1583 elif dtype == DType.INT16:
1584 accum_dtypes = [DType.INT48]
1585 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001586 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01001587 elif dtype == DType.BF16:
1588 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001589 elif dtype == DType.FP32:
1590 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01001591 elif error_name is None:
1592 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
1593
1594 if error_name == ErrorIf.WrongOutputType:
1595 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson1271c442023-09-05 11:39:26 +01001596 accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01001597 elif error_name == ErrorIf.WrongInputType:
1598 # Pick some potentially correct output dtype if input type is incorrect
1599 accum_dtypes = [DType.INT32]
1600
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001601 # Set up compliance info
1602 args_dict = {
1603 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
1604 # Set dot_products = N*H*W
1605 "dot_products": gtu.product(
1606 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
1607 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001608 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001609 }
1610
1611 # Create arg tuple of string and dict
1612 arg_list = []
1613 for a in accum_dtypes:
1614 d = args_dict.copy()
1615 d["acc_type"] = a
1616 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001617
1618 arg_list = TosaArgGen._add_data_generators(
1619 testGen,
1620 opName,
1621 dtype,
1622 arg_list,
1623 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01001624 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001625 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01001626 return arg_list
James Ward8b390432022-08-12 20:48:56 +01001627
    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
        """Generate stride/pad/output-shape argument combinations for TRANSPOSE_CONV2D.

        Args:
            testGen: test generator providing rng, command-line args and helpers
            opName: operator name string
            shapeList: [ifm_shape, filter_shape, ...] input tensor shapes
            dtypes: tensor dtypes, used to derive the accumulator dtype
            error_name: optional ErrorIf name to generate negative tests for

        Returns:
            list of tuples: (arg_str, [accum_dtype, stride, pad, output_shape])
        """
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Must be rank 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == 4
            assert len(filter_shape) == 4

        # Kernel size (H, W) taken from filter shape dims 1 and 2
        k_shape = tuple(filter_shape[1:3])

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            # Most negative padding allowed before the output would vanish
            smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
            if error_name == ErrorIf.PadLargerEqualKernel:
                max_filter_size = -max(k_shape[0], k_shape[1])
                p_vals = [
                    testGen.rng.choice(range(max_filter_size - 10, max_filter_size))
                ]
            else:
                p_vals = [
                    x
                    for x in range(
                        smallest_padding_size, testGen.args.max_conv_padding + 1
                    )
                ]
            paddings = {x for x in itertools.product(*([p_vals] * 4))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
            strides = {x for x in itertools.product(*([s_vals] * 2))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[smallest_padding_size, bigPadding]] * 4)
                            )
                        }
                    )
                bigStride = 8
                strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 10
            sparsity = len(paddings) * len(strides) // sparsity_factor + 1
            # If there are only a small number of tests, just select them all
            if sparsity < 13:
                sparsity = 1
            # To get a variety of parameter combinations sparsity should not be a
            # multiple of 2, 3 or 5
            while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
                sparsity += 1
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            pad_shape = [0] * (len(k_shape) * 2)
            stride_shape = [1] * len(k_shape)
            # The point at which input dimension combined with the stride will
            # create large output sizes!
            LARGE_SIZE = 2
            for idx in range(len(k_shape)):
                pad_offset = idx * 2
                if k_shape[idx] == bigKernel:
                    # Set large stride
                    stride_shape[idx] = bigKernel
                    # Use negative output padding to reduce shape size
                    pad_shape[pad_offset] = -(bigPadding - 1)
                    if ifm_shape[idx + 1] > LARGE_SIZE:
                        pad_shape[pad_offset + 1] = -(bigPadding - 1)
                else:
                    # The other dimension should be the bigKernel
                    alt_idx = 1 - idx
                    if (
                        k_shape[alt_idx] == bigKernel
                        and ifm_shape[alt_idx + 1] < LARGE_SIZE
                    ):
                        # As the input is small, the large stride won't
                        # affect the output so we can add some padding
                        pad_shape[pad_offset + 1] = bigPadding

            strides = {tuple(stride_shape)}
            paddings = {tuple(pad_shape)}

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        # Emit every sparsity'th stride/padding combination
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                if n % sparsity == 0:
                    # Determine the output shape
                    oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
                    ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]

                    # Support for larger values than 9 needs different delimiter
                    delim = "" if max(s + p) <= 9 else "x"
                    arg_list.append(
                        (
                            "acc{}_st{}_pad{}_os{}".format(
                                testGen.typeStr(accum_dtype),
                                delim.join([str(x) for x in s]),
                                delim.join([str(x) for x in p]),
                                "x".join([str(x) for x in os]),
                            ),
                            [accum_dtype, s, p, os],
                        )
                    )
                n += 1

        return arg_list
1760
    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        """Generate padding argument combinations for PAD.

        Args:
            testGen: test generator providing rng, command-line args and helpers
            opName: operator name string
            shapeList: list of input shapes - only shapeList[0] is used
            dtype: input tensor dtype, selects the int or fp pad constant
            error_name: optional ErrorIf name to generate negative tests for

        Returns:
            list of tuples: (arg_str, args_dict)
        """
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            # Use negative padding values to trigger the ERROR_IF
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        # Pick a random pad constant matching the dtype; the other constant is 0
        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            # Unsupported dtype - nothing to generate
            return []

        list_shape_pad_values = list(shape_pad_values)
        # If we are producing tests for rank 6 or greater use sparsity
        if len(list_shape_pad_values) > 1024:
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(list_shape_pad_values), sparsity_factor
            )
        else:
            sparsity = 1

        # Build arg list
        arg_list = []
        for n, paddings in enumerate(list_shape_pad_values):
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                    # NOTE(review): rejects the case as soon as any axis ends up
                    # with no negative padding - presumably so every kept test
                    # still exercises negative padding; confirm intent
                    if all([p > -1 for p in paddings[i]]):
                        args_valid = False
            if args_valid and n % sparsity == 0:
                # Encode the before/after padding of every dimension in the name
                name = "pad"
                for r in range(rank):
                    before, after = paddings[r]
                    name = f"{name}{before}{after}"
                args_dict = {
                    "pad": np.array(paddings),
                    "pad_const_int": pad_const_int,
                    "pad_const_fp": pad_const_fp,
                }
                arg_list.append((name, args_dict))

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list
1835
1836 @staticmethod
1837 def agPooling(testGen, opName, shapeList, dtype, error_name=None):
1838 arg_list = []
1839
1840 shape = shapeList[0]
1841 if error_name != ErrorIf.WrongRank:
1842 assert len(shape) == 4
1843
Jeremy Johnson0c716862023-04-13 17:18:19 +01001844 test_level8k = testGen.args.level8k and error_name is None
1845
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001846 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01001847 startKernel = 2
1848 startPad = 0
1849 if not test_level8k:
1850 # Generate comprehensive argument lists
1851 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
1852 paddings = {x for x in itertools.product(*([p_vals] * 4))}
1853 # Stride must be greater than 1 to force non-integer error
1854 s_vals = [
1855 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
1856 ]
1857 strides = {x for x in itertools.product(*([s_vals] * 2))}
1858 k_vals = [
1859 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
1860 ]
1861 kernels = {x for x in itertools.product(*([k_vals] * 2))}
1862 max_dim_size = None
1863 else:
1864 # Only test 8k levels
1865 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
1866 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
1867 strides = {(1, bigStride), (bigStride, 4)}
1868 kernels = {(1, bigKernel), (bigKernel, 3)}
1869 paddings = set()
1870 for s in sorted(list(strides)):
1871 for k in sorted(list(kernels)):
1872 padding = []
1873 for idx in range(len(k)):
1874 total_padding = s[idx] - shape[idx + 1] + k[idx]
1875 while total_padding < 0:
1876 # Must meet: shape + padding > kernel
1877 total_padding += s[idx]
1878 if total_padding < k[idx]:
1879 padding.extend([0, total_padding])
1880 else:
1881 # Note this may produce padding >= k[idx] which is not
1882 # allowed - but will be ignored in the creation loop below
1883 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
1884 paddings.add(tuple(padding))
1885 # Create a limit for the output dimensions size
1886 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001887
James Ward8b390432022-08-12 20:48:56 +01001888 if opName == "max_pool2d":
1889 accum_dtypes = [None] # max_pool has no accumulate dtype
1890 elif dtype == DType.INT8 or dtype == DType.INT16:
1891 accum_dtypes = [DType.INT32]
1892 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001893 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01001894 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001895 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01001896 elif error_name is None:
1897 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
1898 else:
1899 # Set to something for the ErrorIf case which has
1900 # incorrect input data-type
1901 accum_dtypes = [DType.INT32]
1902
Jeremy Johnson0c716862023-04-13 17:18:19 +01001903 if not test_level8k:
1904 if testGen.args.oversize:
1905 # add some oversize argument values
1906 bigStride = 7
1907 bigKernel = 9
1908 strides.update(
1909 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001910 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01001911 kernels.update(
1912 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
1913 )
1914 if max(shape) < 64:
1915 # padding must be less than the kernel size
1916 bigPadding = bigKernel - 1
1917 paddings.update(
1918 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
1919 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001920
Jeremy Johnson0c716862023-04-13 17:18:19 +01001921 # There are too many parameter combinations, so generate them sparsely,
1922 # very sparse for negative tests
1923 sparsity_factor = 2 if error_name else 500
1924 sparsity = (
1925 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
1926 )
1927 else:
1928 # We have already limited test output combinations for 8k tests
1929 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001930
James Ward8b390432022-08-12 20:48:56 +01001931 arg_str = (
1932 "acc{}_st{}_kern{}_pad{}"
1933 if accum_dtypes[0] is not None
1934 else "st{}_kern{}_pad{}"
1935 )
1936
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001937 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01001938 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001939 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01001940
1941 # Support for larger values than 9 needs different delimiter
1942 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01001943 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01001944 delim.join([str(x) for x in stride]),
1945 delim.join([str(x) for x in kern]),
1946 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01001947 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001948 args_dict = {
1949 "stride": stride,
1950 "pad": pad,
1951 "kernel": kern,
1952 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001953 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001954 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
1955 }
James Ward8b390432022-08-12 20:48:56 +01001956
1957 if accum is not None:
1958 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001959 args_dict["acc_type"] = accum
1960 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01001961
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001962 n = 0
James Ward8b390432022-08-12 20:48:56 +01001963 for a in accum_dtypes:
1964 for s in sorted(list(strides)):
1965 for p in sorted(list(paddings)):
1966 for k in sorted(list(kernels)):
1967 if error_name in [
1968 ErrorIf.StrideSmallerOne,
1969 ErrorIf.KernelSmallerOne,
1970 ErrorIf.PadSmallerZero,
1971 ErrorIf.PadLargerEqualKernel,
1972 ]:
1973 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
1974 testGen, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001975 )
James Ward8b390432022-08-12 20:48:56 +01001976 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001977 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001978 get_arg_list_element(a, sNew, pNew, kNew, shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001979 )
James Ward8b390432022-08-12 20:48:56 +01001980 elif (
1981 n % sparsity == 0
1982 # padding must not exceed the kernel size
1983 and p[0] < k[0]
1984 and p[1] < k[0]
1985 and p[2] < k[1]
1986 and p[3] < k[1]
1987 # the padded shape must exceed the kernel size
1988 and (shape[1] + p[0] + p[1]) > k[0]
1989 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001990 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01001991 partial_h = shape[1] + p[0] + p[1] - k[0]
1992 partial_w = shape[2] + p[2] + p[3] - k[1]
1993 remainder_h = partial_h % s[0]
1994 remainder_w = partial_w % s[1]
1995 output_h = partial_h // s[0] + 1
1996 output_w = partial_w // s[1] + 1
1997 # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
James Ward8b390432022-08-12 20:48:56 +01001998 if (
1999 # the parameters must produce integer exact output
2000 error_name != ErrorIf.PoolingOutputShapeNonInteger
2001 and remainder_h == 0
2002 and remainder_w == 0
2003 ) or (
2004 error_name == ErrorIf.PoolingOutputShapeNonInteger
2005 and (remainder_h != 0 or remainder_w != 0)
2006 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002007 if (
2008 max_dim_size is not None
2009 and max(output_h, output_w) > max_dim_size
2010 ):
2011 # Test will consume too much memory - skip it
2012 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002013 # Dot products = N*OH*OW*C
2014 dp = gtu.product(
2015 (shape[0], output_h, output_w, shape[3])
2016 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002017 arg_list.append(
2018 get_arg_list_element(a, s, p, k, dp, shape)
2019 )
James Ward8b390432022-08-12 20:48:56 +01002020 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002021
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002022 # Now add data generator types
2023 arg_list = TosaArgGen._add_data_generators(
2024 testGen,
2025 opName,
2026 dtype,
2027 arg_list,
2028 error_name,
2029 )
2030
2031 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002032 return arg_list
2033
2034 @staticmethod
2035 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
2036 arg_list = []
2037
2038 # Enumerate the output types here
2039 if error_name == ErrorIf.WrongOutputType:
2040 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
2041 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00002042 dtypeList = [
2043 DType.BOOL,
2044 DType.INT16,
2045 DType.INT32,
2046 DType.FP16,
2047 DType.BF16,
2048 DType.FP32,
2049 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002050 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00002051 dtypeList = [
2052 DType.BOOL,
2053 DType.INT8,
2054 DType.INT32,
2055 DType.FP16,
2056 DType.BF16,
2057 DType.FP32,
2058 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002059 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00002060 dtypeList = [
2061 DType.BOOL,
2062 DType.INT8,
2063 DType.INT16,
2064 DType.FP16,
2065 DType.BF16,
2066 DType.FP32,
2067 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002068 elif inDtype == DType.BOOL:
2069 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01002070 elif inDtype == DType.FP16:
James Ward736fd1a2023-01-23 17:13:37 +00002071 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002072 elif inDtype == DType.BF16:
James Ward736fd1a2023-01-23 17:13:37 +00002073 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002074 elif inDtype == DType.FP32:
James Ward736fd1a2023-01-23 17:13:37 +00002075 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002076 elif error_name == ErrorIf.WrongInputType:
2077 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002078 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002079 else:
2080 raise Exception("Unexpected input dtype: {}".format(inDtype))
2081
2082 for dtype in dtypeList:
Jeremy Johnson3b0544c2022-10-18 16:32:19 +01002083 arg_list.append(("out{}".format(testGen.typeStr(dtype)), [dtype]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002084
2085 return arg_list
2086
    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        """Generate output-dtype/scale32/double_round/per_channel argument
        combinations for RESCALE.

        Invalid dtype pairings and illegal scaling-mode combinations are
        skipped, except where a specific ErrorIf test needs them.

        Args:
            testGen: test generator providing naming helpers
            opName: operator name string
            shapeList: input shapes (unused here)
            inDtype: input tensor dtype
            error_name: optional ErrorIf name to generate negative tests for

        Returns:
            list of tuples: (arg_str, [outDtype, scale32, double_round, per_channel])
        """
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                continue

            # Enumerate all legal scaling-mode combinations for this dtype pair
            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition. Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [outDtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list
2185
2186 @staticmethod
2187 def agMul(testGen, opName, shapeList, dtype, error_name=None):
2188 arg_list = []
2189
2190 if dtype is DType.INT32:
2191 for p in range(testGen.args.num_rand_permutations):
2192
2193 shift = testGen.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002194 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002195 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002196 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002197
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002198 arg_list = TosaArgGen._add_data_generators(
2199 testGen,
2200 opName,
2201 dtype,
2202 arg_list,
2203 error_name,
2204 )
2205 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002206 return arg_list
2207
2208 @staticmethod
2209 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
2210 arg_list = []
2211
2212 arg_list.append(("roundTrue", [True]))
2213 arg_list.append(("roundFalse", [False]))
2214
2215 return arg_list
2216
Luke Hutton57287132023-02-06 14:54:18 +00002217 @staticmethod
2218 def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
2219 arg_list = []
2220
2221 arg_list.append(("inverseTrue", [True]))
2222 arg_list.append(("inverseFalse", [False]))
2223
2224 return arg_list
2225
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002226 # Helper function for reshape. Gets some factors of a larger number.
2227 @staticmethod
2228 def getFactors(val, start=1):
2229 factors = []
2230
2231 for i in range(start, int(np.sqrt(val)) + 1):
2232 if (val % i) == 0:
2233 factors.append(i)
2234
2235 return factors
2236
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        """Generate new-shape argument combinations for RESHAPE.

        Randomly factorizes the input element count into new shapes of random
        rank, emitting both a fully-defined shape and a variant with one
        dimension inferred (-1). ErrorIf names produce deliberately invalid
        inferred shapes instead.

        Args:
            testGen: test generator providing rng, command-line args and helpers
            opName: operator name string
            shapeList: list of input shapes - only shapeList[0] is used
            dtype: input tensor dtype (unused here)
            error_name: optional ErrorIf name to generate negative tests for

        Returns:
            list of tuples: (arg_str, [new_shape])
        """
        arg_list = []

        origShape = shapeList[0]

        # Total number of elements the new shape must preserve
        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            # Rank from 1 to TOSA_TENSOR_MAX_RANK
            newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                new_shape_inferred = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                # Dimension (1-based here) that will be replaced with -1
                inferred_dim = testGen.rng.integers(1, newRank + 1)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    if i == inferred_dim:
                        new_shape_inferred.append(-1)
                    else:
                        new_shape_inferred.append(shuffledFactors[0])
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)
                if inferred_dim == newRank:
                    new_shape_inferred.append(-1)
                else:
                    new_shape_inferred.append(remainingElements)

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                if error_name in [
                    ErrorIf.ReshapeOutputSizeNonInteger,
                    ErrorIf.ReshapeOutputSizeMultiInference,
                ]:
                    if newRank < 2:
                        # Need at least two dimensions
                        continue
                    # NOTE: Change inferred_dim starting offset from 1 to 0
                    inferred_dim -= 1
                    # Pick a second, different dimension to corrupt
                    extra_dim = inferred_dim + testGen.rng.integers(1, newRank)
                    extra_dim = extra_dim % newRank
                    assert extra_dim != inferred_dim
                    if error_name == ErrorIf.ReshapeOutputSizeNonInteger:
                        # Grow extra_dim until the inferred size is non-integer
                        elements = 1
                        for i, dim_value in enumerate(new_shape_inferred):
                            if i != inferred_dim and i != extra_dim:
                                elements *= dim_value
                        dim_value = new_shape_inferred[extra_dim]
                        while totalElements % (elements * dim_value) == 0:
                            dim_value += 1
                        new_shape_inferred[extra_dim] = dim_value
                    else:
                        assert error_name == ErrorIf.ReshapeOutputSizeMultiInference
                        # Two -1 dimensions is invalid
                        new_shape_inferred[extra_dim] = -1
                else:
                    arg_list.append(
                        ("perm{}_rank{}_outdefined".format(p, newRank), [newShape])
                    )
                if error_name != ErrorIf.TensorSizeInputOutputMismatch:
                    arg_list.append(
                        (
                            "perm{}_rank{}_outinferred".format(p, newRank),
                            [new_shape_inferred],
                        )
                    )

        return arg_list
2332
2333 @staticmethod
2334 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
2335 arg_list = []
2336
2337 ifm_shape = shapeList[0]
2338
2339 if error_name == ErrorIf.IndexOutsideBounds:
2340 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
2341 incorrect_small_index = range(-len(ifm_shape), 0)
2342 permutations = [p for p in itertools.permutations(incorrect_large_index)]
2343 permutations.extend(
2344 [p for p in itertools.permutations(incorrect_small_index)]
2345 )
2346 elif error_name == ErrorIf.IndexUsedTwice:
2347 # Create list with a duplicated index
2348 perm_range = list(range(len(ifm_shape)))
2349 index_choice = testGen.rng.choice(range(len(perm_range)))
2350 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
2351 permutations = [p for p in itertools.permutations(perm_range)]
2352
2353 else:
2354 # Get all permutations
2355 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
2356
2357 # Limit to possible permutations from shape dimension or argument setting
2358 limit = min(len(permutations), testGen.args.num_rand_permutations)
2359
2360 # Get random permutation generator that uses all permutations
2361 random_permutations = testGen.rng.permutation(permutations)
2362
2363 # Create list of required amount of permutations
2364 arg_list = [
2365 ("perm{}".format(p), [random_permutations[p].tolist()])
2366 for p in range(limit)
2367 ]
2368 return arg_list
2369
2370 @staticmethod
2371 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
2372 arg_list = []
2373
2374 ifm_shape = shapeList[0]
2375 rank = len(ifm_shape)
2376
2377 for p in range(testGen.args.num_rand_permutations):
2378 start = []
2379 size = []
2380
2381 valid = True
2382
2383 for i in range(rank):
2384 if ifm_shape[i] > 1:
2385 start.append(testGen.randInt(0, ifm_shape[i]))
2386 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
2387
2388 # Invalid slice size?
2389 if size[i] == 0:
2390 valid = False
2391 else:
2392 start.append(0)
2393 size.append(1)
2394
2395 if valid:
2396 # If ERROR_IF test required then incorrect start, size will be returned
2397 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
2398 testGen, error_name, ifm_shape, start, size
2399 )
2400 arg_list.append(("perm{}".format(p), [start, size]))
2401 return arg_list
2402
2403 @staticmethod
2404 def agTile(testGen, opName, shapeList, dtype, error_name=None):
2405 arg_list = []
2406
2407 ifm_shape = shapeList[0]
2408 rank = len(ifm_shape)
2409
2410 for p in range(testGen.args.num_rand_permutations):
2411
2412 # Pick a few random, but small multiple values
2413 # because otherwise this has a tendency to generate
2414 # enormous tensors
2415 multiples = []
2416 for i in range(rank):
2417 if ifm_shape[i] > 1000:
2418 # Multiple of 1 if ifm_shape dimension is large to reduce
2419 # tensor size
2420 multiples.append(1)
2421 elif max(ifm_shape) > 1000:
2422 multiples.append(2)
2423 else:
2424 multiples.append(testGen.randInt(1, 4))
2425 arg_list.append(("perm{}".format(p), [multiples]))
2426
2427 return arg_list
2428
    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Generate argument permutations for RESIZE tests.

        For each legal (mode, output dtype) combination this picks random
        scale/offset/border parameter sets from one of several strategies
        (aspect-ratio, up/downscale, fully random, or 8K-level params),
        rejects combinations whose output dimensions would be non-integer,
        out-of-range or too large, and returns the surviving sets.

        Args:
            testGen: test generator (provides rng, randInt, typeStr, args
                and the TOSA_8K_LEVEL_MAX_SCALE constant)
            opName: operator name (unused here)
            shapeList: input shapes; shapeList[0] is the input shape
                # NOTE(review): indexing [1]/[2] below implies NHWC layout - confirm
            dtype: input data type (DType)
            error_name: optional ErrorIf name; when set, parameters are
                altered to trigger that ERROR_IF condition

        Returns:
            List of (test-name string, [mode, scale, offset, border, dtype,
            outputDType]) tuples.
        """
        arg_list = []
        ifm_shape = shapeList[0]

        def get_aspect_ratio_resize_params():
            # Build a scale from a common aspect ratio, emulating letterbox
            # or pillarbox style resizing via the border parameters.
            common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
            aspect_ratio = testGen.rng.choice(common_aspect_ratios)
            invert = testGen.rng.choice((False, True))
            letterbox = testGen.rng.choice((False, True))

            scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
            scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
            scale_y_d = scale_x_d = 1
            offset_x = offset_y = 0

            if letterbox:
                max_border = scale_y_n
                border_y = testGen.randInt(low=0, high=max_border)
                border_x = 0
            else:
                # Pillarboxing
                border_y = 0
                max_border = scale_x_n
                border_x = testGen.randInt(low=0, high=max_border)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)

            return scale, offset, border

        def get_upscale_downscale_params():
            # Produce either a power-of-two upscale or a small integer
            # downscale parameter set; retries until a valid set is found.
            valid_params = False
            while not valid_params:
                upscale = testGen.rng.choice((False, True))

                # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
                origin_sampling = testGen.rng.choice((False, True))

                if upscale:
                    shift = testGen.randInt(low=1, high=4)
                    scale_x_d = scale_y_d = 1
                    scale_x_n = scale_y_n = (
                        1 << shift if origin_sampling else 2 << shift
                    )
                    border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                    offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
                else:
                    scale_x_n = 1
                    scale_y_n = 1

                    # Return list of valid scale_*_d values (max value 4) given input dim shape
                    # NOTE: ifm_dim % x == 1 means (ifm_dim - 1) is divisible
                    # by x, which keeps the output dimension an integer here.
                    def get_valid_denom(ifm_dim):
                        return [x for x in range(1, 5) if ifm_dim % x == 1]

                    # Generate list of valid downscale values and choose one randomly
                    valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                    valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                    if not valid_scale_y_ds and not valid_scale_x_ds:
                        # Bad parameters, skip
                        continue

                    if not valid_scale_y_ds:
                        scale_y_d = 1
                    else:
                        scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                    if not valid_scale_x_ds:
                        scale_x_d = 1
                    else:
                        scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                    border_x = border_y = 0
                    offset_y = testGen.randInt(0, 16 * scale_y_n)
                    offset_x = testGen.randInt(0, 16 * scale_x_n)
                valid_params = True

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_rand_params():
            # Fully random scale/offset/border within level limits.
            def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
                # Grow the denominator if needed so scale_n / scale_d does
                # not exceed the level maximum.
                scale = scale_n / scale_d
                if scale > max_scale:
                    factor = scale / max_scale
                    new_scale_d = math.ceil(scale_d * factor)
                    assert scale_n / new_scale_d <= max_scale
                    scale_d = new_scale_d
                return scale_d

            # Scale
            scale_y_n = testGen.randInt(low=1, high=(1 << 11))
            scale_x_n = testGen.randInt(low=1, high=(1 << 11))

            scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

            scale_y_d = fix_scale_to_max_scale(
                scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_x_d = fix_scale_to_max_scale(
                scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )

            # Offsets and border within the scale
            offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
            border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_level_8k_params():
            # Parameters that exercise the 8K-level maximum scaling limits.
            # Create 64x scale - 64/1 to 2048/32
            scale_d = testGen.randInt(
                low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
            # Create half to fifth scaling
            scale_d_alt = testGen.randInt(low=2, high=6)
            scale_n_alt = 1
            switch = testGen.rng.choice((False, True))
            if switch:
                scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
            else:
                scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)

            # Extreme (minimum, zero, near-maximum) offsets and borders
            offset_y = testGen.rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
            offset_x = testGen.rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
            offset = (offset_y, offset_x)
            border_y = testGen.rng.choice((-16 * scale[0], 0, scale[0] - 1))
            border_x = testGen.rng.choice((-16 * scale[2], 0, scale[2] - 1))
            border = (border_y, border_x)
            return scale, offset, border

        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.BF16:
                outputDTypeList = [DType.BF16]
            elif dtype == DType.FP32:
                outputDTypeList = [DType.FP32]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                perm = 0
                while perm < testGen.args.num_rand_permutations:
                    # Random choice of type of params we are testing
                    if not testGen.args.level8k:
                        _rnd_param_fn = testGen.rng.choice(
                            (
                                get_rand_params,
                                get_upscale_downscale_params,
                                get_aspect_ratio_resize_params,
                            )
                        )
                        scale, offset, border = _rnd_param_fn()
                    else:
                        scale, offset, border = get_level_8k_params()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        # Look for non-integer test
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            if perm > 0:
                                perm += 1
                            continue
                    else:
                        # Alter the scaling factors to make the output integer
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1
                        # Make sure we are still within max scaling
                        if (
                            scale_y_n / scale_y_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
                            scale_x_n / scale_x_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
                            # Skip the test as it is using too large a scaling factor
                            if perm > 0:
                                perm += 1
                            continue

                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1

                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        if not testGen.args.level8k or perm > 0:
                            perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= gtu.MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= gtu.MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]

                    # Common for all data types
                    if error_name is not None:
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        [
                            mode,
                            scale,
                            offset,
                            border,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)
        return arg_list
2748
2749 @staticmethod
2750 def agTable(testGen, opName, shapeList, dtype, error_name=None):
2751 arg_list = []
2752
2753 if dtype == DType.INT8:
2754 table = np.int32(
2755 testGen.rng.integers(low=-128, high=128, size=[256])
2756 ).tolist()
2757 else: # INT16
2758 table = np.int32(
2759 testGen.rng.integers(low=-32768, high=32768, size=[513])
2760 ).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07002761 # Make sure all slopes are within REQUIRE min/max 16-bit int
2762 for idx in range(len(table) - 1):
2763 slope = table[idx + 1] - table[idx]
2764 # Alter the next table entry to force the slope to be ok
2765 if slope > 32767:
2766 table[idx + 1] -= slope - 32767
2767 if slope < -32768:
2768 table[idx + 1] -= slope + 32768
2769 slope = table[idx + 1] - table[idx]
2770 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002771 arg_list.append(
2772 (
2773 "",
2774 [table],
2775 )
2776 )
2777 return arg_list
2778
2779 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
2780 # CondIf generates the condition values here.
2781 # Convert to tensors in the build function, along with the
2782 # then and else blocks
2783 arg_list = []
2784
2785 for c in [False, True]:
2786 arg_list.append(("cond{}".format(int(c)), [c]))
2787
2788 return arg_list
2789
2790 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
2791 # While loop: 0 iterations, 1, more than 1
2792 arg_list = []
2793
2794 for iter in [0, 1, 4]:
2795 arg_list.append(("iter{}".format(iter), [iter]))
2796
2797 return arg_list