blob: 6675025cd1c6df561b6381a2df93a878249e7fdd [file] [log] [blame]
Luke Hutton261b7b62023-01-10 14:50:31 +00001# Copyright (c) 2021-2023, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
4import math
James Ward8b390432022-08-12 20:48:56 +01005import warnings
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01006
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
16# DTypeNames, DType, Op and ResizeMode are convenience variables to the
17# flatc-generated types that should be enums, but aren't
18
19
20class TosaQuantGen:
21 """QuantizedInfo random generator helper functions.
22
23 Specify with 'qgen': in the operator defintion.
24 """
25
26 def __init__(self):
27 pass
28
29 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +000030 def getZeroPoint(testGen, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010031
32 if dtype == DType.INT8:
Jeremy Johnson00423432022-09-12 17:27:37 +010033 if testGen.args.zeropoint is not None:
34 return min(127, max(-128, testGen.args.zeropoint))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010035 return testGen.randInt(-128, 128)
36 elif dtype == DType.UINT8:
Jeremy Johnson00423432022-09-12 17:27:37 +010037 if testGen.args.zeropoint is not None:
38 return min(255, max(0, testGen.args.zeropoint))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010039 return testGen.randInt(0, 256)
40 elif error_name in [
41 ErrorIf.InputZeroPointNotZero,
42 ErrorIf.WeightZeroPointNotZero,
43 ErrorIf.OutputZeroPointNotZero,
44 ]:
45 zero_point = testGen.randInt(-128, 128)
46 if zero_point == 0:
47 zero_point = 1
48 return zero_point
49 return 0
50
51 @staticmethod
52 def qgUnary(testGen, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010053 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000054 qinfo = [
55 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
56 TosaQuantGen.getZeroPoint(testGen, dtype),
57 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010058 elif error_name == ErrorIf.OutputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000059 qinfo = [
60 TosaQuantGen.getZeroPoint(testGen, dtype),
61 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
62 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010063 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000064 qinfo = [
65 TosaQuantGen.getZeroPoint(testGen, dtype),
66 TosaQuantGen.getZeroPoint(testGen, dtype),
67 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010068 return qinfo
69
70 @staticmethod
71 def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010072 if isinstance(dtype_or_dtypeList, list):
73 # a list of [input, weights, accumulator] dtypes
74 dtypeList = dtype_or_dtypeList
75 else:
76 # an int, [input, weights, accumulator] dtypes are the same
77 dtypeList = [dtype_or_dtypeList] * 3
78
79 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000080 qinfo = [
81 TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
82 TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
83 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010084 elif error_name == ErrorIf.WeightZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000085 qinfo = [
86 TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
87 TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
88 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010089 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000090 qinfo = [
91 TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
92 TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
93 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010094 return qinfo
95
96 @staticmethod
97 def qgMatmul(testGen, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010098 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000099 qinfo = [
100 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
101 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
102 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100103 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000104 qinfo = [
105 TosaQuantGen.getZeroPoint(testGen, dtype),
106 TosaQuantGen.getZeroPoint(testGen, dtype),
107 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100108 return qinfo
109
110 @staticmethod
111 def computeMultiplierAndShift(scaleFp, scale32):
112 # Derived from computeMultiplierAndShiftTosaScale32
113 # Provide a floating-point scaling factor and the scale32 parameter
114 # to compute the multiplier and shift
115
116 if scale32:
117 scaleBits = 31
118 else:
119 scaleBits = 15
120
121 m, shift = math.frexp(scaleFp)
122
123 if scaleFp < 0.0:
124 m = -m
125
126 multiplier = round(m * (1 << scaleBits))
127 assert multiplier <= (1 << scaleBits)
128
129 if multiplier == (1 << scaleBits):
130 multiplier = multiplier // 2
131 shift = shift + 1
132
133 shift = (-shift) + scaleBits
134 # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
135 # scaleFp, scaleBits, m, multiplier, shift))
136
137 # Adjust multiplier such that shift is in allowed value range.
138 if shift == 0:
139 multiplier = multiplier // 4
140 shift = shift + 2
141 elif shift == 1:
142 multiplier = multiplier // 2
143 shift = shift + 1
144 elif shift == 63:
145 multiplier = multiplier * 2
146 shift = shift - 1
147
148 assert multiplier <= (1 << scaleBits)
149 assert shift >= 2 and shift <= 62
150
151 return multiplier, shift
152
153
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.

    NOTE(review): testGen appears to be the test harness object providing the
    RNG, shape helpers and CLI args - project type, not visible in this file.
    """

    def __init__(self):
        # Stateless - all generators are static methods
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Default generator: one identical random shape per operand."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Generator for rank 4 (NHWC) operators - batch size is constricted."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Generator for SCATTER/GATHER style ops: [values_in, indices-width] shapes."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = min(values_in_shape[0], testGen.args.max_batch_size)

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Generator that makes one randomly-chosen operand broadcastable by
        setting a random dimension to 1 (or deliberately mismatched for errors)."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        fuzz_idx = testGen.randInt(0, rank)

        for i in range(pl + const):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Generator for CONV2D: returns [ifm (NHWC), filter (OHWI), bias (OC)]."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Generator for CONV3D: returns [ifm (NDHWC), filter (ODHWI), bias (OC)]."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension()

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Generator for TRANSPOSE_CONV2D: returns [ifm (NHWC), filter (OHWI), bias (OC)]."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Generator for DEPTHWISE_CONV2D: returns [ifm (NHWC), filter (HWCM), bias (C*M)]."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        """Generator for FFT2D: two identical NHW inputs with power-of-two H/W."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1 (1 + 1 is still a power of two)
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        """Generator for RFFT2D: one NHW input with power-of-two H/W."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Generator for FULLY_CONNECTED: returns [input (NC), filter (OC,IC), bias (OC)]."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Generator for MATMUL: returns [a (N,H,C), b (N,C,W)] shapes."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Generator for CONCAT: base operands plus a random number of extra inputs."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    # Drop a dimension to force a rank mismatch
                    wrongShape = wrongShape[1:]
                else:
                    # Append an extra dimension to force a rank mismatch
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Post-process CONCAT shapes: split the first shape along axis so the
        extra const inputs don't multiply the total tensor size."""
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
604
605
606class TosaTensorValuesGen:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100607 """Tensor Value generators create the random data for each tensor in each test."""
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100608
    def __init__(self):
        # Stateless helper class - all functionality is via static methods
        pass
611
    class TVGInfo:
        """Enhanced tensor values information including data gen dict."""

        def __init__(self, tensorList, dataGenDict):
            # Serializer tensors created for the test (placeholders and consts)
            self.tensorList = tensorList
            # Data generation meta-data dict (None when data was created directly)
            self.dataGenDict = dataGenDict
618
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100619 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000620 def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100621 pCount, cCount = op["operands"]
622
623 tens = []
624 tens.extend(
625 testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
626 )
627 tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
628
629 return tens
630
    # Default high value for random numbers
    # These are the largest finite values of each format, expressed as
    # (2 - 2**-mantissa_bits) * 2**max_exponent, e.g. for FP32:
    # (2 - 2**-23) * 2**127 == (1 << 128) - (1 << (127 - 23))
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
    }
637
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100638 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000639 def _get_data_range(testGen, dtype, highValueLookup):
640 if dtype in highValueLookup:
641 data_range = testGen.getDTypeRange(dtype, high_inclusive=True)
642 high_val = highValueLookup[dtype]
643 # Set the values to something that won't produce infinity whilst
644 # respecting the default ranges if less than the high value
645 return [
646 max(-high_val, data_range[0]),
647 min(high_val, data_range[1]),
648 ]
649 return None
650
    @staticmethod
    def tvgLazyGenDefault(
        testGen, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Create tensors plus the data-generation meta-data used by the
        compliance data generator library.

        Falls back to directly-generated random data for ERROR_IF tests,
        unsupported dtypes or ops without a "data_gen" entry; otherwise builds
        a per-tensor meta-data dict and (unless lazy data gen is enabled) asks
        the data gen library (testGen.dgl) to produce the data now.
        """
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
        tens_ser_list = []

        if (
            error_name is not None
            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
            or "data_gen" not in testGen.TOSA_OP_LIST[opName]
        ):
            # Fall back to internal data gen when dealing with unsupported types or ops
            data_range = argsDict["data_range"] if "data_range" in argsDict else None
            for idx, info in enumerate(zip(shapeList, dtypeList)):
                shape, dtype = info
                # Ignore lazy data gen option and create data array using any range limits
                arr = testGen.getRandTensor(shape, dtype, data_range)
                if idx < pCount:
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))

            # No data gen meta-data in this path
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
        dg_type = argsDict["dg_type"]
        tens_data = {
            "version": "0.1",
            "tensors": {},
        }
        dg_tens_meta = tens_data["tensors"]
        for idx, shape in enumerate(shapeList):

            tens_meta = {}
            tens_meta["generator"] = gtu.DataGenType(dg_type).name
            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()

            # First pCount tensors are the variable inputs, the rest are consts
            if idx < pCount:
                tens_meta["input_type"] = "VARIABLE"
            else:
                tens_meta["input_type"] = "CONSTANT"

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                # TODO - generate seed for this generator based on test
                info["rng_seed"] = 42
                if "data_range" in argsDict:
                    data_range = argsDict["data_range"]
                else:
                    data_range = testGen.getDTypeRange(
                        dtypeList[idx], high_inclusive=True
                    )
                info["range"] = [str(v) for v in data_range]
                tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = int(argsDict["ks"])
                if "acc_type" in argsDict:
                    # Convert type number into JSON name
                    info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
                        "json"
                    ]
                if "kernel" in argsDict:
                    info["kernel"] = [int(k) for k in argsDict["kernel"]]
                if "axis" in argsDict:
                    info["axis"] = int(argsDict["axis"])
                tens_meta["dot_product_info"] = info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"

            # Using the finished generate config meta data - generate the data if
            # needed and assign a tensor name from the serializer

            # Need to generate data when not lazy or for the bias tensor as we need
            # to work out if the bias data is non-zero for compliance
            if not testGen.args.lazy_data_gen or (
                idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
            ):
                # Give this tensor a temporary name until we get one from the serializer
                temp_name = f"placeholder_{idx}"
                dg_tens_meta[temp_name] = tens_meta
                # Create data now using the temporary name to access meta details
                data = testGen.dgl.get_tensor_data(temp_name, tens_data)
                # Remove the item as we will give it the correct name later
                del dg_tens_meta[temp_name]

                if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
                    # The KS value used by compliance verification is altered when the
                    # bias data is non-zero
                    if max(abs(data)) > 0.0:
                        argsDict["ksb"] = argsDict["ks"] + 1

            # In lazy mode the serializer gets no data - it is generated later
            if testGen.args.lazy_data_gen:
                data = None

            if tens_meta["input_type"] == "VARIABLE":
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
            else:
                tens = testGen.ser.addConst(shape, dtypeList[idx], data)

            tens_ser_list.append(tens)
            # Add the meta data to the list using the serializer tensor name
            dg_tens_meta[tens.name] = tens_meta

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
763
764 @staticmethod
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000765 def tvgNegate(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100766 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000767 # Integer test
768 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100769 pCount, cCount = op["operands"]
770 assert (
771 pCount == 1 and cCount == 0
772 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100773 # Must create tensors with values within accumulator (int32) negatable
774 # range
775 max_val = (1 << 31) - 1
776 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100777 arr = np.int32(
778 testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
779 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000780 tens_ser_list = []
781 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100782 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
783 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000784 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100785 else:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000786 # ERROR_IF or floating point test
787 return TosaTensorValuesGen.tvgLazyGenDefault(
788 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100789 )
790
    # Set the data range to half the largest value: the sum (or difference) of
    # two in-range operands then stays within the largest finite value
    TVG_FLOAT_HIGH_VALUE_ADDSUB = {
        DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
        DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
        DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
    }
797
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100798 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000799 def tvgAddSub(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100800 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000801 # Make sure the integer operation does not cause value saturation - where
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100802 # the number wraps due to limited number of bits to store the answer
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000803 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100804 pCount, cCount = op["operands"]
805 assert (
806 pCount == 2 and cCount == 0
807 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000808 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100809 add = op["op"] == Op.ADD
810 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
811 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
812 if add:
813 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
814 else:
815 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
816
817 # Work out the saturation limits
818 max_i32 = (1 << 31) - 1
819 min_i32 = -(1 << 31)
820 max_arr = np.full(shapeList[1], max_i32)
821 min_arr = np.full(shapeList[1], min_i32)
822
823 # Find how much values exceed the maximum/minimums
824 sat_max_arr = np.maximum(res_arr - max_arr, 0)
825 sat_min_arr = np.minimum(res_arr - min_arr, 0)
826
827 if not add:
828 # Swap saturation values and negate values as we need to perform opposite operations
829 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
830
831 # Create new array of unsaturated values by clipping values as needed
832 b_unsat_arr = b_arr
833 if (sat_max_arr != 0).any():
834 # Clip values that cause saturation
835 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
836 # Reduce axes in unsaturated tensor to match original tensor
837 for axis, dim in enumerate(b_arr.shape):
838 if dim != b_unsat_arr.shape[axis]:
839 assert (
840 dim == 1
841 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
842 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
843
844 if (sat_min_arr != 0).any():
845 # Clip values that cause saturation
846 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
847 # Reduce axes in unsaturated tensor to match original tensor
848 for axis, dim in enumerate(b_arr.shape):
849 if dim != b_unsat_arr.shape[axis]:
850 assert (
851 dim == 1
852 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
853 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
854
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000855 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100856 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
857 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000858 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100859 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
860 )
861
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000862 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100863 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000864 # ERROR_IF or floating point test
865 data_range = TosaTensorValuesGen._get_data_range(
866 testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
867 )
868 if data_range:
869 argsDict["data_range"] = data_range
870
871 return TosaTensorValuesGen.tvgLazyGenDefault(
872 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100873 )
874
875 @staticmethod
876 def tvgCondIfWhileLoop(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000877 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100878 ):
879 if dtypeList[0] in (
880 DType.INT32,
881 DType.INT16,
882 DType.INT8,
883 ):
884 # Limit input tensors with cond_if_binary or while_loop to stop
885 # saturation of add/sub ops with int32 and keep all logical shift
886 # values between 0 to 31 for int16 or int8
887 pCount, cCount = op["operands"]
888 pRemain = pCount
889 placeholders = []
890 for idx, shape in enumerate(shapeList[:]):
891 if dtypeList[0] == DType.INT32:
892 arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
893 else:
894 arr = np.int32(
895 testGen.rng.integers(low=0, high=32, size=shapeList[idx])
896 )
897 if pRemain > 0:
898 placeholders.append(
899 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
900 )
901 pRemain -= 1
902 else:
903 placeholders.append(
904 testGen.ser.addConst(shape, dtypeList[idx], arr)
905 )
906
907 return placeholders
908 else:
909 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000910 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100911 )
912
913 @staticmethod
914 def tvgArithmeticRightShift(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000915 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100916 ):
917 pCount, cCount = op["operands"]
918 # Force value of operand[1] to be within [0, num_bits]
919 assert (
920 pCount == 2 and cCount == 0
921 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
922
923 placeholders = []
924 for idx, shape in enumerate(shapeList[:]):
925 if idx == 1:
926 if dtypeList[idx] == DType.INT8:
927 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
928 elif dtypeList[idx] == DType.INT16:
929 arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
930 elif dtypeList[idx] == DType.INT32:
931 arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
932 elif error_name == ErrorIf.WrongInputType:
933 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
934 else:
935 raise Exception("OpArithmeticRightShift: invalid input dtype")
936 else:
937 arr = testGen.getRandTensor(shape, dtypeList[idx])
938 placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
939
940 return placeholders
941
942 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000943 def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100944 # Set datatype of condition tensor to boolean
945 dtypeList[0] = DType.BOOL
946
947 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000948 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100949 )
950
951 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000952 def tvgIntDiv(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100953 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000954 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100955 pCount, cCount = op["operands"]
956 assert (
957 pCount == 2 and cCount == 0
958 ), "Op.INTDIV must have 2 placeholders, 0 consts"
959
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000960 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100961
962 # Two invalid cases for Op.INTDIV:
963 # 1. divisor == 0
964 # 2. dividend == -(1<<31) and divisor == -1
965 while True:
966 dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
967 divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
968
969 if (divisor_arr == 0).any():
970 continue
971
972 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
973 continue
974
975 break
976
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000977 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100978 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
979 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000980 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100981 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
982 )
983
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000984 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100985 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000986 return TosaTensorValuesGen.tvgLazyGenDefault(
987 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100988 )
989
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100990 # Set the data range to the square root of the largest value
991 TVG_FLOAT_HIGH_VALUE_MUL = {
992 DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
993 DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
994 DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
995 }
996
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100997 @staticmethod
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100998 def tvgMul(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
999 if error_name is not None or dtypeList[0] in (
1000 DType.FP16,
1001 DType.BF16,
1002 DType.FP32,
1003 ):
1004 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001005 data_range = TosaTensorValuesGen._get_data_range(
1006 testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
1007 )
1008 if data_range:
1009 argsDict["data_range"] = data_range
1010
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001011 return TosaTensorValuesGen.tvgLazyGenDefault(
1012 testGen, opName, dtypeList, shapeList, argsDict, error_name
1013 )
1014 else:
1015 # Integer test
1016 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001017 pCount, cCount = op["operands"]
1018 assert (
1019 pCount == 2 and cCount == 0
1020 ), "Op.MUL must have 2 placeholders, 0 consts"
1021
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001022 tens_ser_list = []
1023
1024 # Make sure multiply result in int32 range
1025 shift = argsDict["shift"]
1026 if dtypeList[0] == DType.INT8:
1027 num_bits = 8
1028 elif dtypeList[0] == DType.INT16:
1029 num_bits = 16
1030 elif dtypeList[0] == DType.INT32:
1031 num_bits = 32
1032 elif error_name == ErrorIf.WrongInputType:
1033 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001034 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001035 raise Exception("OpMul: invalid input dtype")
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001036
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001037 for idx, shape in enumerate(shapeList[:]):
1038 low = -(2 ** (num_bits - 1))
1039 high = (2 ** (num_bits - 1)) - 1
1040
1041 a_arr = np.int32(
1042 testGen.rng.integers(low=low, high=high, size=shapeList[0])
1043 )
1044 b_arr = np.int32(
1045 testGen.rng.integers(low=low, high=high, size=shapeList[1])
1046 )
1047
1048 i = 0
1049 while True:
1050
1051 a_arr_64 = a_arr.astype(np.int64)
1052 b_arr_64 = b_arr.astype(np.int64)
1053
1054 if shift > 0:
1055 rounding = 1 << (shift - 1)
1056 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001057 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001058 result_arr = a_arr_64 * b_arr_64
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001059
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001060 if (result_arr > -(2**31)).all() and (
1061 result_arr <= ((2**31) - 1)
1062 ).all():
1063 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001064
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001065 i = i + 1
1066 a_arr = a_arr // 2
1067 b_arr = b_arr // 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001068
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001069 tens_ser_list.append(
1070 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001071 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001072 tens_ser_list.append(
1073 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1074 )
1075
1076 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001077
1078 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001079 def tvgConcat(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001080 count = len(shapeList) - testGen.args.num_const_inputs_concat
1081 if count < 1:
1082 count = 1
1083 if testGen.args.num_const_inputs_concat == 0:
1084 count = len(shapeList)
1085
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001086 shapeList = TosaTensorGen.tgConcatConstInput(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001087 testGen, shapeList, argsDict["axis"], error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001088 )
1089
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001090 tens_ser_list = []
1091 tens_ser_list.extend(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001092 testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
1093 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001094 tens_ser_list.extend(
1095 testGen.buildConstTensors(shapeList[count:], dtypeList[count:])
1096 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001097
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001098 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001099
1100 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001101 def tvgLogicalShift(
1102 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1103 ):
1104 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001105 pCount, cCount = op["operands"]
1106 assert (
1107 pCount == 2 and cCount == 0
1108 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
1109 values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1110 shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001111 tens_ser_list = []
1112 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001113 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1114 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001115 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001116 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1117 )
1118
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001119 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001120
1121 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001122 def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001123 if error_name is None:
1124 pCount, cCount = op["operands"]
1125 assert (
1126 pCount == 2 and cCount == 0
1127 ), "Op.EQUAL must have 2 placeholders, 0 consts"
1128 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1129 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
1130 # Using random numbers means that it will be very unlikely that
1131 # there are any matching (equal) values, therefore force that
1132 # there are twice the number of matching values as the tensor rank
1133 for num in range(0, len(shapeList[0]) * 2):
1134 a_index = []
1135 b_index = []
1136 # Choose an index in each axis for the whole shape
1137 for axis in range(0, len(shapeList[0])):
1138 # Index can be up to the largest dimension in both shapes
1139 index = np.int32(
1140 testGen.rng.integers(
1141 0, max(shapeList[0][axis], shapeList[1][axis])
1142 )
1143 )
1144 # Reduce the index down to a shape's dim for broadcasting
1145 a_index.append(min(shapeList[0][axis] - 1, index))
1146 b_index.append(min(shapeList[1][axis] - 1, index))
1147
1148 a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
1149
1150 placeholders = []
1151 placeholders.append(
1152 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1153 )
1154 placeholders.append(
1155 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1156 )
1157 return placeholders
1158 else:
1159 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001160 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001161 )
1162
1163 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001164 def tvgReduceSum(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001165 if dtypeList[0] == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001166 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001167 pCount, cCount = op["operands"]
1168 assert (
1169 pCount == 1 and cCount == 0
1170 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
1171 # Limit values so that the sum cannot exceed the range of an int32 during
1172 # summation of any axis
1173 range_val = int((1 << 31) / max(shapeList[0]))
1174 values_arr = np.int32(
1175 testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
1176 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001177 tens_ser_list = []
1178 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001179 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1180 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001181 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001182 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001183 # ERROR_IF or dot product floating point test
1184 return TosaTensorValuesGen.tvgLazyGenDefault(
1185 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001186 )
1187
1188
1189class TosaArgGen:
1190 """Argument generators create exhaustive or random lists of attributes for
1191 operators that take attributes or other parameters.
1192
1193 The return value is a list of (descriptive_name, [arglist]) tuples where
1194 the descriptive_name is appended to the test name and the arglist is expanded
1195 as arguments to the operator build function.
1196 """
1197
    def __init__(self):
        # No instance state - all argument generators are static methods
        pass
1200
1201 @staticmethod
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001202 def _add_data_generators(testGen, opName, dtype, arg_list, error_name):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001203 """Add extra tests for each type of data generator for this op."""
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001204 if (
1205 error_name is None
1206 and "data_gen" in testGen.TOSA_OP_LIST[opName]
1207 and gtu.dtypeIsSupportedByCompliance(dtype)
1208 ):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001209 if dtype in [DType.FP16, DType.FP32, DType.BF16]:
1210 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
1211 else:
1212 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
1213 else:
1214 # Error test or No data generator types listed - assume random
1215 dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
1216
1217 # Expand arg list with other data generator types
1218 new_arg_list = []
1219 for dg_type in dataGenTypesList:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001220 for arg_str, args_dict in arg_list:
1221 args_dict["dg_type"] = dg_type
Jeremy Johnson1271c442023-09-05 11:39:26 +01001222 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
1223 # Default test
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001224 new_arg_list.append((arg_str, args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001225
1226 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
1227 # Extra tests for each dot product test set
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001228 dot_products = args_dict["dot_products"]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001229 if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001230 shape_info = (
1231 " ({})".format(testGen.shapeStr(args_dict["shape"]))
1232 if "shape" in args_dict
1233 else ""
1234 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01001235 print(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001236 f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
Jeremy Johnson1271c442023-09-05 11:39:26 +01001237 )
1238 continue
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001239 # KS and acc_type is required by all dot product generators
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001240 assert "ks" in args_dict
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001241 assert "acc_type" in args_dict
Jeremy Johnson1271c442023-09-05 11:39:26 +01001242
1243 for s in testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS:
1244 new_arg_str = f"{arg_str}_s{s}"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001245 new_args_dict = args_dict.copy()
1246 new_args_dict["s"] = s
1247 new_arg_list.append((new_arg_str, new_args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001248
1249 return new_arg_list
1250
1251 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001252 def agNone(testGen, opName, shapeList, dtype, error_name=None):
1253 """A trivial argument generator for operators that don't take any
1254 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001255 arg_list = TosaArgGen._add_data_generators(
1256 testGen,
1257 opName,
1258 dtype,
1259 [("", {})],
1260 error_name,
1261 )
1262 # Return list of tuples: (arg_str, args_dict)
1263 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001264
1265 @staticmethod
1266 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
1267 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001268 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001269 shape = shapeList[0]
1270
1271 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001272 # Set too small axis
1273 axes = [testGen.rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001274 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001275 # Set too large axis
1276 axes = [testGen.rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001277 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001278 # Create tests for each dimension
1279 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001280
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001281 opid = testGen.TOSA_OP_LIST[opName]["op"]
1282
1283 for a in axes:
1284 args_dict = {"axis": int(a)}
1285 if opid == Op.REDUCE_SUM:
1286 args_dict["dot_products"] = gtu.product(shape)
1287 args_dict["shape"] = shape
1288 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1289 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1290
1291 arg_list.append(("axis{}".format(a), args_dict))
1292
1293 arg_list = TosaArgGen._add_data_generators(
1294 testGen,
1295 opName,
1296 dtype,
1297 arg_list,
1298 error_name,
1299 )
1300 # Return list of tuples: (arg_str, args_dict)
1301 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001302
1303 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001304 def _calculate_sparsity(num_tests, sparsity_factor):
1305 sparsity = num_tests // sparsity_factor + 1
1306 # If there are only a small number of tests, just select them all
1307 if sparsity < 13:
1308 sparsity = 1
1309 # To get a variety of parameter combinations sparsity should not be a
1310 # multiple of 2, 3 or 5
1311 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1312 sparsity += 1
1313 return sparsity
1314
1315 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01001316 def agConv(testGen, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson0c716862023-04-13 17:18:19 +01001317 # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001318 arg_list = []
1319
Jeremy Johnson0c716862023-04-13 17:18:19 +01001320 if testGen.args.level8k and error_name is not None:
1321 # Don't produce negative large tests
1322 return arg_list
1323
1324 # Shape: Batches, (Depth), Height, Width, Channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001325 ifm_shape = shapeList[0]
Jeremy Johnson0c716862023-04-13 17:18:19 +01001326 # Shape: (OFM channels), (KD), KH, KW, IFM channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001327 filter_shape = shapeList[1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001328
Jeremy Johnson1271c442023-09-05 11:39:26 +01001329 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01001330
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001331 # Op type checks
Jeremy Johnson0c716862023-04-13 17:18:19 +01001332 conv3d = opName.startswith("conv3d")
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001333 depthwise = opName.startswith("depthwise")
1334
1335 # Check the rank
Jeremy Johnson0c716862023-04-13 17:18:19 +01001336 rank = 5 if conv3d else 4
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001337 if error_name != ErrorIf.WrongRank:
1338 assert len(ifm_shape) == rank
1339 assert len(filter_shape) == rank
1340
Jeremy Johnson0c716862023-04-13 17:18:19 +01001341 # kernel rank omits channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001342 k_rank = rank - 2
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001343 k_pos = 0 if depthwise else 1
Jeremy Johnson0c716862023-04-13 17:18:19 +01001344 k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001345 # compliance size - KS
1346 k_size = gtu.product(k_shape)
1347 if not depthwise:
1348 k_size *= ifm_shape[-1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001349
Jeremy Johnson0c716862023-04-13 17:18:19 +01001350 if not testGen.args.level8k:
1351 # Generate comprehensive argument lists
1352 # - except for named errors, which use specific invalid value(s)
1353 if error_name == ErrorIf.PadSmallerZero:
1354 p_vals = [testGen.rng.choice(range(-5, 0))]
1355 else:
1356 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
1357 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
1358 if error_name == ErrorIf.StrideSmallerOne:
1359 # Can't use stride=0, as it is used to derive output shape, as a divisor
1360 s_vals = [testGen.rng.choice(range(-5, 0))]
1361 else:
1362 # Stride must be greater than 1 to force non-integer error
1363 startStride = (
1364 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001365 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01001366 s_vals = [
1367 x for x in range(startStride, testGen.args.max_conv_stride + 1)
1368 ]
1369 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
1370 if error_name == ErrorIf.DilationSmallerOne:
1371 d_vals = [testGen.rng.choice(range(-5, 1))]
1372 else:
1373 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
1374 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001375
Jeremy Johnson0c716862023-04-13 17:18:19 +01001376 if not error_name and testGen.args.oversize:
1377 # add some oversize argument values
1378 if max(ifm_shape) < 64:
1379 bigPadding = 9
1380 paddings.update(
1381 {
1382 x
1383 for x in itertools.product(
1384 *([[0, bigPadding]] * (k_rank * 2))
1385 )
1386 }
1387 )
1388 bigStride = 8
1389 strides.update(
1390 {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
1391 )
1392 bigDilation = 7
1393 dilations.update(
1394 {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
1395 )
1396 max_dim_size = None
1397
1398 # There are too many parameter combinations, so generate them sparsely,
1399 # very sparse for negative tests
1400 sparsity_factor = 2 if error_name else 120
1401 sparsity = TosaArgGen._calculate_sparsity(
1402 len(paddings) * len(strides) * len(dilations), sparsity_factor
1403 )
1404 else:
1405 # Only test 8k levels boundaries
1406 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
1407 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
1408 bigPadding = bigKernel
1409
1410 dilation_shape = [1] * k_rank
1411 pad_shape = [0] * k_rank * 2
1412 if conv3d:
1413 # Small stride apart from for big kernel (see below) to keep
1414 # tensor size/calculation small
1415 stride_shape = [1] * k_rank
1416 for idx in range(k_rank):
1417 pad_offset = idx * 2
1418 if k_shape[idx] == bigKernel:
1419 # Padding shape needs to account for tensor shape
1420 pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
1421 pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
1422 # Big stride to reduce output size
1423 stride_shape[idx] = bigKernel
1424 else:
1425 # Account for kernel size
1426 pad_shape[pad_offset] = k_shape[idx] - 1
1427 else:
1428 # Always have a large stride with extra padding and dilation to keep
1429 # tensor calculation reasonable
1430 stride_shape = [bigKernel] * k_rank
1431 for idx in range(k_rank):
1432 # Dilation shape must account for kernel size
1433 dilation_shape[idx] = bigKernel // k_shape[idx]
1434 # Padding shape needs to accommodate tensor/kernel & dilation
1435 pad_offset = idx * 2
1436 pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
1437 pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
1438
1439 strides = {tuple(stride_shape)}
1440 dilations = {tuple(dilation_shape)}
1441 paddings = {tuple(pad_shape)}
1442 # Create a limit for the output dimensions size
1443 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
1444
1445 # Currently allow all combinations that are reasonable size
1446 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001447
1448 n = 0
1449 for s in sorted(list(strides)):
1450 for p in sorted(list(paddings)):
1451 for d in sorted(list(dilations)):
1452 if (
1453 n % sparsity == 0
Jeremy Johnson93d43902022-09-27 12:26:14 +01001454 # the padded shape must exceed the dilation * kernel to get a positive
1455 # sized output shape
Jeremy Johnson0c716862023-04-13 17:18:19 +01001456 and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
1457 and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
Jeremy Johnson93d43902022-09-27 12:26:14 +01001458 and (
1459 k_rank < 3
Jeremy Johnson0c716862023-04-13 17:18:19 +01001460 or (
1461 (ifm_shape[3] - 1 + p[4] + p[5])
1462 > d[2] * (k_shape[2] - 1)
1463 )
Jeremy Johnson93d43902022-09-27 12:26:14 +01001464 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001465 ):
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001466 remainders = []
Jeremy Johnson0c716862023-04-13 17:18:19 +01001467 outputs = []
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001468 for index in range(k_rank):
1469 pad_offset = index * 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01001470 partial = (
1471 ifm_shape[index + 1]
1472 - 1
1473 + p[pad_offset]
1474 + p[pad_offset + 1]
1475 - (k_shape[index] - 1) * d[index]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001476 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01001477 remainders.append(partial % s[index])
1478 outputs.append((partial // s[index]) + 1)
1479
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001480 if (
1481 # the parameters must produce integer exact output
1482 error_name != ErrorIf.ConvOutputShapeNonInteger
1483 and max(remainders) == 0
1484 ) or (
1485 error_name == ErrorIf.ConvOutputShapeNonInteger
1486 and max(remainders) > 0
1487 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01001488 if (
1489 max_dim_size is not None
1490 and max(outputs) >= max_dim_size
1491 ):
1492 # Test will consume too much memory - skip it
1493 continue
1494
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001495 # Compliance - number of dot product calculations
1496 if depthwise:
1497 # TODO - add support
1498 dots = 0
1499 else:
1500 dots = gtu.product(
1501 (ifm_shape[0], *outputs, filter_shape[0])
1502 )
1503 args_dict = {
1504 "acc_type": accum_dtype,
1505 "stride": s,
1506 "pad": p,
1507 "dilation": d,
1508 "kernel": k_shape,
1509 "ks": k_size,
1510 "dot_products": dots,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001511 "shape": ifm_shape,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001512 }
1513
Jeremy Johnson0c716862023-04-13 17:18:19 +01001514 # Support for larger values than 9 needs different delimiter
1515 delim = "" if max(s + p + d) <= 9 else "x"
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001516 arg_list.append(
1517 (
James Ward8b390432022-08-12 20:48:56 +01001518 "acc{}_st{}_pad{}_dilat{}".format(
1519 testGen.typeStr(accum_dtype),
Jeremy Johnson0c716862023-04-13 17:18:19 +01001520 delim.join([str(x) for x in s]),
1521 delim.join([str(x) for x in p]),
1522 delim.join([str(x) for x in d]),
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001523 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001524 args_dict,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001525 )
1526 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001527 n += 1
1528
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001529 arg_list = TosaArgGen._add_data_generators(
1530 testGen,
1531 opName,
1532 dtypes[0],
1533 arg_list,
1534 error_name,
1535 )
1536 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001537 return arg_list
1538
1539 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01001540 def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
1541
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001542 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001543 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01001544
1545 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson1271c442023-09-05 11:39:26 +01001546 accum_dtype = gtu.get_wrong_output_type(opName, testGen.rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01001547 elif error_name == ErrorIf.WrongInputType:
1548 # Pick some potentially correct output dtype if input type is incorrect
1549 accum_dtype = DType.INT32
1550 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +01001551 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01001552
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001553 # Set up compliance info
1554 args_dict = {
1555 "acc_type": accum_dtype,
1556 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
1557 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
1558 "shape": shapeList[0],
1559 }
1560
1561 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
1562
1563 arg_list = TosaArgGen._add_data_generators(
1564 testGen,
1565 opName,
1566 input_dtype,
1567 arg_list,
1568 error_name,
1569 )
1570 # Return list of tuples: (arg_str, args_dict)
1571 return arg_list
James Ward8b390432022-08-12 20:48:56 +01001572
1573 @staticmethod
1574 def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
1575 # Get valid accumulate type(s)
1576 if dtype == DType.INT8:
1577 accum_dtypes = [DType.INT32]
1578 elif dtype == DType.INT16:
1579 accum_dtypes = [DType.INT48]
1580 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001581 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01001582 elif dtype == DType.BF16:
1583 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001584 elif dtype == DType.FP32:
1585 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01001586 elif error_name is None:
1587 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
1588
1589 if error_name == ErrorIf.WrongOutputType:
1590 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson1271c442023-09-05 11:39:26 +01001591 accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01001592 elif error_name == ErrorIf.WrongInputType:
1593 # Pick some potentially correct output dtype if input type is incorrect
1594 accum_dtypes = [DType.INT32]
1595
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001596 # Set up compliance info
1597 args_dict = {
1598 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
1599 # Set dot_products = N*H*W
1600 "dot_products": gtu.product(
1601 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
1602 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001603 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001604 }
1605
1606 # Create arg tuple of string and dict
1607 arg_list = []
1608 for a in accum_dtypes:
1609 d = args_dict.copy()
1610 d["acc_type"] = a
1611 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001612
1613 arg_list = TosaArgGen._add_data_generators(
1614 testGen,
1615 opName,
1616 dtype,
1617 arg_list,
1618 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01001619 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001620 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01001621 return arg_list
James Ward8b390432022-08-12 20:48:56 +01001622
    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
        """Generate stride/pad/output-shape argument combinations for TRANSPOSE_CONV2D.

        Args:
            testGen: test generator providing rng, command line args and helpers
            opName: operator name (unused directly here)
            shapeList: [ifm_shape, filter_shape, ...] - both asserted rank 4
                unless testing ErrorIf.WrongRank
            dtypes: tensor dtypes, used to derive the accumulator dtype
            error_name: optional ErrorIf name to generate invalid arguments for

        Returns:
            List of (arg_str, [accum_dtype, stride, pad, output_shape]) tuples.
        """
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Must be rank 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == 4
            assert len(filter_shape) == 4

        # Kernel height/width taken from the filter shape
        # (assumes OHWI filter layout - TODO confirm against caller)
        k_shape = tuple(filter_shape[1:3])

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
            if error_name == ErrorIf.PadLargerEqualKernel:
                # Padding chosen to be <= -kernel so the error fires
                max_filter_size = -max(k_shape[0], k_shape[1])
                p_vals = [
                    testGen.rng.choice(range(max_filter_size - 10, max_filter_size))
                ]
            else:
                p_vals = [
                    x
                    for x in range(
                        smallest_padding_size, testGen.args.max_conv_padding + 1
                    )
                ]
            # All 4-tuples (top, bottom, left, right) of the candidate values
            paddings = {x for x in itertools.product(*([p_vals] * 4))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
            strides = {x for x in itertools.product(*([s_vals] * 2))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[smallest_padding_size, bigPadding]] * 4)
                            )
                        }
                    )
                bigStride = 8
                strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 10
            sparsity = len(paddings) * len(strides) // sparsity_factor + 1
            # If there are only a small number of tests, just select them all
            if sparsity < 13:
                sparsity = 1
            # To get a variety of parameter combinations sparsity should not be a
            # multiple of 2, 3 or 5
            while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
                sparsity += 1
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            pad_shape = [0] * (len(k_shape) * 2)
            stride_shape = [1] * len(k_shape)
            # The point at which input dimension combined with the stride will
            # create large output sizes!
            LARGE_SIZE = 2
            for idx in range(len(k_shape)):
                pad_offset = idx * 2
                if k_shape[idx] == bigKernel:
                    # Set large stride
                    stride_shape[idx] = bigKernel
                    # Use negative output padding to reduce shape size
                    pad_shape[pad_offset] = -(bigPadding - 1)
                    if ifm_shape[idx + 1] > LARGE_SIZE:
                        pad_shape[pad_offset + 1] = -(bigPadding - 1)
                else:
                    # The other dimension should be the bigKernel
                    alt_idx = 1 - idx
                    if (
                        k_shape[alt_idx] == bigKernel
                        and ifm_shape[alt_idx + 1] < LARGE_SIZE
                    ):
                        # As the input is small, the large stride won't
                        # affect the output so we can add some padding
                        pad_shape[pad_offset + 1] = bigPadding

            strides = {tuple(stride_shape)}
            paddings = {tuple(pad_shape)}

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                if n % sparsity == 0:
                    # Determine the output shape
                    oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
                    ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]

                    # Support for larger values than 9 needs different delimiter
                    delim = "" if max(s + p) <= 9 else "x"
                    arg_list.append(
                        (
                            "acc{}_st{}_pad{}_os{}".format(
                                testGen.typeStr(accum_dtype),
                                delim.join([str(x) for x in s]),
                                delim.join([str(x) for x in p]),
                                "x".join([str(x) for x in os]),
                            ),
                            [accum_dtype, s, p, os],
                        )
                    )
                n += 1

        return arg_list
1755
    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        """Generate padding argument combinations for PAD tests.

        Args:
            testGen: test generator providing rng, command line args and helpers
            opName: operator name
            shapeList: list of input shapes - first entry is the padded tensor
            dtype: input tensor dtype - decides int vs float pad constant
            error_name: optional ErrorIf name to generate invalid arguments for

        Returns:
            List of (arg_str, args_dict) tuples, where args_dict holds the
            pad array and both pad constants; post-processed by
            TosaArgGen._add_data_generators.
        """
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            # Negative padding values for the ErrorIf case
            pad_values = [x for x in range(-2, 0)]
        # (before, after) pairs per axis, then the cross product over all axes
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        # Only one of the two pad constants is randomised, depending on dtype
        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            # Unsupported dtype - no arguments generated
            return []

        list_shape_pad_values = list(shape_pad_values)
        # If we are producing tests for rank 6 or greater use sparsity
        if len(list_shape_pad_values) > 1024:
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(list_shape_pad_values), sparsity_factor
            )
        else:
            sparsity = 1

        # Build arg list
        arg_list = []
        for n, paddings in enumerate(list_shape_pad_values):
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                # NOTE(review): `i` here is the leaked loop variable (last axis),
                # so only the final axis's padding is checked for negativity -
                # confirm whether all axes were intended to be checked
                if all([p > -1 for p in paddings[i]]):
                    args_valid = False
            if args_valid and n % sparsity == 0:
                # Name encodes the before/after padding of every axis
                name = "pad"
                for r in range(rank):
                    before, after = paddings[r]
                    name = f"{name}{before}{after}"
                args_dict = {
                    "pad": np.array(paddings),
                    "pad_const_int": pad_const_int,
                    "pad_const_fp": pad_const_fp,
                }
                arg_list.append((name, args_dict))

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list
1830
1831 @staticmethod
1832 def agPooling(testGen, opName, shapeList, dtype, error_name=None):
1833 arg_list = []
1834
1835 shape = shapeList[0]
1836 if error_name != ErrorIf.WrongRank:
1837 assert len(shape) == 4
1838
Jeremy Johnson0c716862023-04-13 17:18:19 +01001839 test_level8k = testGen.args.level8k and error_name is None
1840
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001841 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01001842 startKernel = 2
1843 startPad = 0
1844 if not test_level8k:
1845 # Generate comprehensive argument lists
1846 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
1847 paddings = {x for x in itertools.product(*([p_vals] * 4))}
1848 # Stride must be greater than 1 to force non-integer error
1849 s_vals = [
1850 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
1851 ]
1852 strides = {x for x in itertools.product(*([s_vals] * 2))}
1853 k_vals = [
1854 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
1855 ]
1856 kernels = {x for x in itertools.product(*([k_vals] * 2))}
1857 max_dim_size = None
1858 else:
1859 # Only test 8k levels
1860 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
1861 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
1862 strides = {(1, bigStride), (bigStride, 4)}
1863 kernels = {(1, bigKernel), (bigKernel, 3)}
1864 paddings = set()
1865 for s in sorted(list(strides)):
1866 for k in sorted(list(kernels)):
1867 padding = []
1868 for idx in range(len(k)):
1869 total_padding = s[idx] - shape[idx + 1] + k[idx]
1870 while total_padding < 0:
1871 # Must meet: shape + padding > kernel
1872 total_padding += s[idx]
1873 if total_padding < k[idx]:
1874 padding.extend([0, total_padding])
1875 else:
1876 # Note this may produce padding >= k[idx] which is not
1877 # allowed - but will be ignored in the creation loop below
1878 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
1879 paddings.add(tuple(padding))
1880 # Create a limit for the output dimensions size
1881 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001882
James Ward8b390432022-08-12 20:48:56 +01001883 if opName == "max_pool2d":
1884 accum_dtypes = [None] # max_pool has no accumulate dtype
1885 elif dtype == DType.INT8 or dtype == DType.INT16:
1886 accum_dtypes = [DType.INT32]
1887 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001888 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01001889 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001890 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01001891 elif error_name is None:
1892 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
1893 else:
1894 # Set to something for the ErrorIf case which has
1895 # incorrect input data-type
1896 accum_dtypes = [DType.INT32]
1897
Jeremy Johnson0c716862023-04-13 17:18:19 +01001898 if not test_level8k:
1899 if testGen.args.oversize:
1900 # add some oversize argument values
1901 bigStride = 7
1902 bigKernel = 9
1903 strides.update(
1904 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001905 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01001906 kernels.update(
1907 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
1908 )
1909 if max(shape) < 64:
1910 # padding must be less than the kernel size
1911 bigPadding = bigKernel - 1
1912 paddings.update(
1913 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
1914 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001915
Jeremy Johnson0c716862023-04-13 17:18:19 +01001916 # There are too many parameter combinations, so generate them sparsely,
1917 # very sparse for negative tests
1918 sparsity_factor = 2 if error_name else 500
1919 sparsity = (
1920 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
1921 )
1922 else:
1923 # We have already limited test output combinations for 8k tests
1924 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001925
James Ward8b390432022-08-12 20:48:56 +01001926 arg_str = (
1927 "acc{}_st{}_kern{}_pad{}"
1928 if accum_dtypes[0] is not None
1929 else "st{}_kern{}_pad{}"
1930 )
1931
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001932 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01001933 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001934 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01001935
1936 # Support for larger values than 9 needs different delimiter
1937 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01001938 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01001939 delim.join([str(x) for x in stride]),
1940 delim.join([str(x) for x in kern]),
1941 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01001942 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001943 args_dict = {
1944 "stride": stride,
1945 "pad": pad,
1946 "kernel": kern,
1947 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001948 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001949 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
1950 }
James Ward8b390432022-08-12 20:48:56 +01001951
1952 if accum is not None:
1953 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001954 args_dict["acc_type"] = accum
1955 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01001956
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001957 n = 0
James Ward8b390432022-08-12 20:48:56 +01001958 for a in accum_dtypes:
1959 for s in sorted(list(strides)):
1960 for p in sorted(list(paddings)):
1961 for k in sorted(list(kernels)):
1962 if error_name in [
1963 ErrorIf.StrideSmallerOne,
1964 ErrorIf.KernelSmallerOne,
1965 ErrorIf.PadSmallerZero,
1966 ErrorIf.PadLargerEqualKernel,
1967 ]:
1968 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
1969 testGen, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001970 )
James Ward8b390432022-08-12 20:48:56 +01001971 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001972 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001973 get_arg_list_element(a, sNew, pNew, kNew, shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001974 )
James Ward8b390432022-08-12 20:48:56 +01001975 elif (
1976 n % sparsity == 0
1977 # padding must not exceed the kernel size
1978 and p[0] < k[0]
1979 and p[1] < k[0]
1980 and p[2] < k[1]
1981 and p[3] < k[1]
1982 # the padded shape must exceed the kernel size
1983 and (shape[1] + p[0] + p[1]) > k[0]
1984 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001985 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01001986 partial_h = shape[1] + p[0] + p[1] - k[0]
1987 partial_w = shape[2] + p[2] + p[3] - k[1]
1988 remainder_h = partial_h % s[0]
1989 remainder_w = partial_w % s[1]
1990 output_h = partial_h // s[0] + 1
1991 output_w = partial_w // s[1] + 1
1992 # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
James Ward8b390432022-08-12 20:48:56 +01001993 if (
1994 # the parameters must produce integer exact output
1995 error_name != ErrorIf.PoolingOutputShapeNonInteger
1996 and remainder_h == 0
1997 and remainder_w == 0
1998 ) or (
1999 error_name == ErrorIf.PoolingOutputShapeNonInteger
2000 and (remainder_h != 0 or remainder_w != 0)
2001 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002002 if (
2003 max_dim_size is not None
2004 and max(output_h, output_w) > max_dim_size
2005 ):
2006 # Test will consume too much memory - skip it
2007 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002008 # Dot products = N*OH*OW*C
2009 dp = gtu.product(
2010 (shape[0], output_h, output_w, shape[3])
2011 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002012 arg_list.append(
2013 get_arg_list_element(a, s, p, k, dp, shape)
2014 )
James Ward8b390432022-08-12 20:48:56 +01002015 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002016
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002017 # Now add data generator types
2018 arg_list = TosaArgGen._add_data_generators(
2019 testGen,
2020 opName,
2021 dtype,
2022 arg_list,
2023 error_name,
2024 )
2025
2026 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002027 return arg_list
2028
2029 @staticmethod
2030 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
2031 arg_list = []
2032
2033 # Enumerate the output types here
2034 if error_name == ErrorIf.WrongOutputType:
2035 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
2036 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00002037 dtypeList = [
2038 DType.BOOL,
2039 DType.INT16,
2040 DType.INT32,
2041 DType.FP16,
2042 DType.BF16,
2043 DType.FP32,
2044 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002045 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00002046 dtypeList = [
2047 DType.BOOL,
2048 DType.INT8,
2049 DType.INT32,
2050 DType.FP16,
2051 DType.BF16,
2052 DType.FP32,
2053 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002054 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00002055 dtypeList = [
2056 DType.BOOL,
2057 DType.INT8,
2058 DType.INT16,
2059 DType.FP16,
2060 DType.BF16,
2061 DType.FP32,
2062 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002063 elif inDtype == DType.BOOL:
2064 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01002065 elif inDtype == DType.FP16:
James Ward736fd1a2023-01-23 17:13:37 +00002066 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002067 elif inDtype == DType.BF16:
James Ward736fd1a2023-01-23 17:13:37 +00002068 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002069 elif inDtype == DType.FP32:
James Ward736fd1a2023-01-23 17:13:37 +00002070 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002071 elif error_name == ErrorIf.WrongInputType:
2072 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002073 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002074 else:
2075 raise Exception("Unexpected input dtype: {}".format(inDtype))
2076
2077 for dtype in dtypeList:
Jeremy Johnson3b0544c2022-10-18 16:32:19 +01002078 arg_list.append(("out{}".format(testGen.typeStr(dtype)), [dtype]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002079
2080 return arg_list
2081
    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        """Generate output dtype / scale32 / double_round / per_channel
        argument combinations for RESCALE tests.

        Invalid in/out dtype pairings are skipped, except where a specific
        ErrorIf case deliberately requires them.

        Returns:
            List of (arg_str, [outDtype, scale32, double_round, per_channel])
            tuples.
        """
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                # These output dtypes may have a non-zero zero point, so the
                # ErrorIf case does not apply
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                # Not a suitable wrong-output-type pairing
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition. Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [outDtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list
2180
2181 @staticmethod
2182 def agMul(testGen, opName, shapeList, dtype, error_name=None):
2183 arg_list = []
2184
2185 if dtype is DType.INT32:
2186 for p in range(testGen.args.num_rand_permutations):
2187
2188 shift = testGen.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002189 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002190 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002191 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002192
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002193 arg_list = TosaArgGen._add_data_generators(
2194 testGen,
2195 opName,
2196 dtype,
2197 arg_list,
2198 error_name,
2199 )
2200 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002201 return arg_list
2202
2203 @staticmethod
2204 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
2205 arg_list = []
2206
2207 arg_list.append(("roundTrue", [True]))
2208 arg_list.append(("roundFalse", [False]))
2209
2210 return arg_list
2211
Luke Hutton57287132023-02-06 14:54:18 +00002212 @staticmethod
2213 def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
2214 arg_list = []
2215
2216 arg_list.append(("inverseTrue", [True]))
2217 arg_list.append(("inverseFalse", [False]))
2218
2219 return arg_list
2220
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002221 # Helper function for reshape. Gets some factors of a larger number.
2222 @staticmethod
2223 def getFactors(val, start=1):
2224 factors = []
2225
2226 for i in range(start, int(np.sqrt(val)) + 1):
2227 if (val % i) == 0:
2228 factors.append(i)
2229
2230 return factors
2231
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        """Generate new-shape arguments for RESHAPE tests.

        For each random permutation a new shape with the same number of
        elements is built from the factors of the input's element count;
        both a fully-defined shape and a variant with one dimension
        inferred (-1) are produced. ErrorIf cases corrupt the inferred
        variant as required.

        Returns:
            List of (arg_str, [new_shape]) tuples.
        """
        arg_list = []

        origShape = shapeList[0]

        # Total element count must be preserved by any new shape
        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            # Rank from 1 to TOSA_TENSOR_MAX_RANK
            newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
            if len(factors) < newRank:
                # Not enough small factors to fill this rank
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                new_shape_inferred = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                # 1-based index of the dimension that will be -1 (inferred)
                inferred_dim = testGen.rng.integers(1, newRank + 1)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    if i == inferred_dim:
                        new_shape_inferred.append(-1)
                    else:
                        new_shape_inferred.append(shuffledFactors[0])
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # Last dimension absorbs whatever element count remains
                newShape.append(remainingElements)
                if inferred_dim == newRank:
                    new_shape_inferred.append(-1)
                else:
                    new_shape_inferred.append(remainingElements)

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                if error_name in [
                    ErrorIf.ReshapeOutputSizeNonInteger,
                    ErrorIf.ReshapeOutputSizeMultiInference,
                ]:
                    if newRank < 2:
                        # Need at least two dimensions
                        continue
                    # NOTE: Change inferred_dim starting offset from 1 to 0
                    inferred_dim -= 1
                    # Pick a different dimension to corrupt
                    extra_dim = inferred_dim + testGen.rng.integers(1, newRank)
                    extra_dim = extra_dim % newRank
                    assert extra_dim != inferred_dim
                    if error_name == ErrorIf.ReshapeOutputSizeNonInteger:
                        # Bump the chosen dimension until the element count no
                        # longer divides evenly - forcing a non-integer size
                        # for the inferred dimension
                        elements = 1
                        for i, dim_value in enumerate(new_shape_inferred):
                            if i != inferred_dim and i != extra_dim:
                                elements *= dim_value
                        dim_value = new_shape_inferred[extra_dim]
                        while totalElements % (elements * dim_value) == 0:
                            dim_value += 1
                        new_shape_inferred[extra_dim] = dim_value
                    else:
                        assert error_name == ErrorIf.ReshapeOutputSizeMultiInference
                        # Two -1 dimensions - multiple inference is invalid
                        new_shape_inferred[extra_dim] = -1
                else:
                    arg_list.append(
                        ("perm{}_rank{}_outdefined".format(p, newRank), [newShape])
                    )
                if error_name != ErrorIf.TensorSizeInputOutputMismatch:
                    arg_list.append(
                        (
                            "perm{}_rank{}_outinferred".format(p, newRank),
                            [new_shape_inferred],
                        )
                    )

        return arg_list
2327
2328 @staticmethod
2329 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
2330 arg_list = []
2331
2332 ifm_shape = shapeList[0]
2333
2334 if error_name == ErrorIf.IndexOutsideBounds:
2335 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
2336 incorrect_small_index = range(-len(ifm_shape), 0)
2337 permutations = [p for p in itertools.permutations(incorrect_large_index)]
2338 permutations.extend(
2339 [p for p in itertools.permutations(incorrect_small_index)]
2340 )
2341 elif error_name == ErrorIf.IndexUsedTwice:
2342 # Create list with a duplicated index
2343 perm_range = list(range(len(ifm_shape)))
2344 index_choice = testGen.rng.choice(range(len(perm_range)))
2345 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
2346 permutations = [p for p in itertools.permutations(perm_range)]
2347
2348 else:
2349 # Get all permutations
2350 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
2351
2352 # Limit to possible permutations from shape dimension or argument setting
2353 limit = min(len(permutations), testGen.args.num_rand_permutations)
2354
2355 # Get random permutation generator that uses all permutations
2356 random_permutations = testGen.rng.permutation(permutations)
2357
2358 # Create list of required amount of permutations
2359 arg_list = [
2360 ("perm{}".format(p), [random_permutations[p].tolist()])
2361 for p in range(limit)
2362 ]
2363 return arg_list
2364
2365 @staticmethod
2366 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
2367 arg_list = []
2368
2369 ifm_shape = shapeList[0]
2370 rank = len(ifm_shape)
2371
2372 for p in range(testGen.args.num_rand_permutations):
2373 start = []
2374 size = []
2375
2376 valid = True
2377
2378 for i in range(rank):
2379 if ifm_shape[i] > 1:
2380 start.append(testGen.randInt(0, ifm_shape[i]))
2381 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
2382
2383 # Invalid slice size?
2384 if size[i] == 0:
2385 valid = False
2386 else:
2387 start.append(0)
2388 size.append(1)
2389
2390 if valid:
2391 # If ERROR_IF test required then incorrect start, size will be returned
2392 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
2393 testGen, error_name, ifm_shape, start, size
2394 )
2395 arg_list.append(("perm{}".format(p), [start, size]))
2396 return arg_list
2397
2398 @staticmethod
2399 def agTile(testGen, opName, shapeList, dtype, error_name=None):
2400 arg_list = []
2401
2402 ifm_shape = shapeList[0]
2403 rank = len(ifm_shape)
2404
2405 for p in range(testGen.args.num_rand_permutations):
2406
2407 # Pick a few random, but small multiple values
2408 # because otherwise this has a tendency to generate
2409 # enormous tensors
2410 multiples = []
2411 for i in range(rank):
2412 if ifm_shape[i] > 1000:
2413 # Multiple of 1 if ifm_shape dimension is large to reduce
2414 # tensor size
2415 multiples.append(1)
2416 elif max(ifm_shape) > 1000:
2417 multiples.append(2)
2418 else:
2419 multiples.append(testGen.randInt(1, 4))
2420 arg_list.append(("perm{}".format(p), [multiples]))
2421
2422 return arg_list
2423
    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (mode, scale, offset, border, dtypes) arguments for RESIZE.

        Four nested helpers each produce one style of (scale, offset, border)
        parameters; a random helper is chosen per permutation (or the 8K-level
        helper when --level8k is set).  Candidates are then filtered/adjusted
        so the output dimensions are integral and within bounds, and ERROR_IF
        variants are substituted via eiResizeErrorIf.

        NOTE(review): ifm_shape[1]/ifm_shape[2] are used as the spatial H/W
        dimensions — assumes NHWC input layout; confirm against callers.
        """
        arg_list = []
        ifm_shape = shapeList[0]

        def get_aspect_ratio_resize_params():
            # Scale by a common aspect ratio (optionally inverted), with a
            # random letterbox (y border) or pillarbox (x border) border
            common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
            aspect_ratio = testGen.rng.choice(common_aspect_ratios)
            invert = testGen.rng.choice((False, True))
            letterbox = testGen.rng.choice((False, True))

            scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
            scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
            scale_y_d = scale_x_d = 1
            offset_x = offset_y = 0

            if letterbox:
                max_border = scale_y_n
                border_y = testGen.randInt(low=0, high=max_border)
                border_x = 0
            else:
                # Pillarboxing
                border_y = 0
                max_border = scale_x_n
                border_x = testGen.randInt(low=0, high=max_border)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)

            return scale, offset, border

        def get_upscale_downscale_params():
            # Power-of-two upscale or small-denominator downscale; loops
            # until a valid parameter set is found
            valid_params = False
            while not valid_params:
                upscale = testGen.rng.choice((False, True))

                # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
                origin_sampling = testGen.rng.choice((False, True))

                if upscale:
                    shift = testGen.randInt(low=1, high=4)
                    scale_x_d = scale_y_d = 1
                    scale_x_n = scale_y_n = (
                        1 << shift if origin_sampling else 2 << shift
                    )
                    border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                    offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
                else:
                    scale_x_n = 1
                    scale_y_n = 1

                    # Return list of valid scale_*_d values (max value 4) given input dim shape
                    # (ifm_dim % x == 1 means (ifm_dim - 1) is divisible by x
                    # for x > 1; x == 1 never matches and is handled by the
                    # empty-list fallback below)
                    def get_valid_denom(ifm_dim):
                        return [x for x in range(1, 5) if ifm_dim % x == 1]

                    # Generate list of valid downscale values and choose one randomly
                    valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                    valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                    if not valid_scale_y_ds and not valid_scale_x_ds:
                        # Bad parameters, skip
                        continue

                    if not valid_scale_y_ds:
                        scale_y_d = 1
                    else:
                        scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                    if not valid_scale_x_ds:
                        scale_x_d = 1
                    else:
                        scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                    border_x = border_y = 0
                    offset_y = testGen.randInt(0, 16 * scale_y_n)
                    offset_x = testGen.randInt(0, 16 * scale_x_n)
                valid_params = True

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_rand_params():
            # Fully random scale/offset/border, clamped to the 8K-level
            # maximum scale factor
            def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
                # Grow the denominator just enough that n/d <= max_scale
                scale = scale_n / scale_d
                if scale > max_scale:
                    factor = scale / max_scale
                    new_scale_d = math.ceil(scale_d * factor)
                    assert scale_n / new_scale_d <= max_scale
                    scale_d = new_scale_d
                return scale_d

            # Scale
            scale_y_n = testGen.randInt(low=1, high=(1 << 11))
            scale_x_n = testGen.randInt(low=1, high=(1 << 11))

            scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

            scale_y_d = fix_scale_to_max_scale(
                scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_x_d = fix_scale_to_max_scale(
                scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )

            # Offsets and border within the scale
            offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
            border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_level_8k_params():
            # Extreme parameters to exercise the 8K level limits: maximum
            # upscale on one axis, half-to-fifth downscale on the other
            # Create 64x scale - 64/1 to 2048/32
            scale_d = testGen.randInt(
                low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
            # Create half to fifth scaling
            scale_d_alt = testGen.randInt(low=2, high=6)
            scale_n_alt = 1
            switch = testGen.rng.choice((False, True))
            if switch:
                scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
            else:
                scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)

            # Offsets/borders at the extreme ends of their legal ranges
            offset_y = testGen.rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
            offset_x = testGen.rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
            offset = (offset_y, offset_x)
            border_y = testGen.rng.choice((-16 * scale[0], 0, scale[0] - 1))
            border_x = testGen.rng.choice((-16 * scale[2], 0, scale[2] - 1))
            border = (border_y, border_x)
            return scale, offset, border

        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.BF16:
                outputDTypeList = [DType.BF16]
            elif dtype == DType.FP32:
                outputDTypeList = [DType.FP32]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                perm = 0
                # perm only advances on accepted candidates (and on some
                # skips once at least one test exists), so the loop keeps
                # drawing params until enough distinct tests are produced
                while perm < testGen.args.num_rand_permutations:
                    # Random choice of type of params we are testing
                    if not testGen.args.level8k:
                        _rnd_param_fn = testGen.rng.choice(
                            (
                                get_rand_params,
                                get_upscale_downscale_params,
                                get_aspect_ratio_resize_params,
                            )
                        )
                        scale, offset, border = _rnd_param_fn()
                    else:
                        scale, offset, border = get_level_8k_params()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        # Look for non-integer test
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            if perm > 0:
                                perm += 1
                            continue
                    else:
                        # Alter the scaling factors to make the output integer
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1
                        # Make sure we are still within max scaling
                        if (
                            scale_y_n / scale_y_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
                            scale_x_n / scale_x_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
                            # Skip the test as it is using too large a scaling factor
                            if perm > 0:
                                perm += 1
                            continue

                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1

                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        if not testGen.args.level8k or perm > 0:
                            perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= gtu.MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= gtu.MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    # Convert to lists for the serialized argument form
                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]

                    # Common for all data types
                    if error_name is not None:
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        [
                            mode,
                            scale,
                            offset,
                            border,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)
        return arg_list
2743
2744 @staticmethod
2745 def agTable(testGen, opName, shapeList, dtype, error_name=None):
2746 arg_list = []
2747
2748 if dtype == DType.INT8:
2749 table = np.int32(
2750 testGen.rng.integers(low=-128, high=128, size=[256])
2751 ).tolist()
2752 else: # INT16
2753 table = np.int32(
2754 testGen.rng.integers(low=-32768, high=32768, size=[513])
2755 ).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07002756 # Make sure all slopes are within REQUIRE min/max 16-bit int
2757 for idx in range(len(table) - 1):
2758 slope = table[idx + 1] - table[idx]
2759 # Alter the next table entry to force the slope to be ok
2760 if slope > 32767:
2761 table[idx + 1] -= slope - 32767
2762 if slope < -32768:
2763 table[idx + 1] -= slope + 32768
2764 slope = table[idx + 1] - table[idx]
2765 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002766 arg_list.append(
2767 (
2768 "",
2769 [table],
2770 )
2771 )
2772 return arg_list
2773
2774 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
2775 # CondIf generates the condition values here.
2776 # Convert to tensors in the build function, along with the
2777 # then and else blocks
2778 arg_list = []
2779
2780 for c in [False, True]:
2781 arg_list.append(("cond{}".format(int(c)), [c]))
2782
2783 return arg_list
2784
2785 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
2786 # While loop: 0 iterations, 1, more than 1
2787 arg_list = []
2788
2789 for iter in [0, 1, 4]:
2790 arg_list.append(("iter{}".format(iter), [iter]))
2791
2792 return arg_list