# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
import math
import warnings

import generator.tosa_utils as gtu
import numpy as np
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode

# DTypeNames, DType, Op and ResizeMode are convenience variables to the
# flatc-generated types that should be enums, but aren't


class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(testGen, dtype, error_name=None):

        if dtype == DType.INT8:
            if testGen.args.zeropoint is not None:
                return min(127, max(-128, testGen.args.zeropoint))
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            if testGen.args.zeropoint is not None:
                return min(255, max(0, testGen.args.zeropoint))
            return testGen.randInt(0, 256)
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            zero_point = testGen.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
        #     scaleFp, scaleBits, m, multiplier, shift))

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

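        # Worked example: scaleFp=0.75 with scale32=True gives
        # frexp -> (0.75, 0), multiplier = round(0.75 * (1 << 31)) = 1610612736
        # and shift = 0 + 31 = 31; check: 1610612736 * 2**-31 == 0.75 exactly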
        return multiplier, shift


class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = min(values_in_shape[0], testGen.args.max_batch_size)

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        fuzz_idx = testGen.randInt(0, rank)

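        # e.g. with shape [3, 4, 5] and fuzz_idx=1, the input chosen by
        # bcast_idx gets shape [3, 1, 5] while the others keep the full shape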
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If this is the chosen input, broadcast along the fuzz index
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension()

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))
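        # (e.g. a height or width of 12 is reduced to 8)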

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
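        # (e.g. three inputs of shape [2, 12, 3] with axis=1 become
        # [2, 12, 3], [2, 6, 3] and [2, 6, 3])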
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList


class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each tensor in each test."""

    def __init__(self):
        pass

    class TVGInfo:
        """Enhanced tensor values information including data gen dict."""

        def __init__(self, tensorList, dataGenDict):
            self.tensorList = tensorList
            self.dataGenDict = dataGenDict

    @staticmethod
    def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        pCount, cCount = op["operands"]

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        return tens

    # Default high value for random numbers
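    # (these are the largest finite values of each format, e.g. 65504.0 for FP16)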
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
    }

    @staticmethod
    def tvgLazyGenDefault(
        testGen, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]

        if (
            error_name is not None
            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
            or "data_gen" not in testGen.TOSA_OP_LIST[opName]
        ):
            # Fall back to original path when dealing with unsupported types or ops

            # First turn off lazy data gen so we always produce data
            lazy_data_gen = testGen.args.lazy_data_gen
            testGen.args.lazy_data_gen = False

            tens_ser_list = TosaTensorValuesGen.tvgDefault(
                testGen,
                testGen.TOSA_OP_LIST[opName],
                dtypeList,
                shapeList,
                [],
                error_name,
            )
            # Restore lazy data gen setting
            testGen.args.lazy_data_gen = lazy_data_gen
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
        dg_type = argsDict["dg_type"]
        tens_data = {
            "version": "0.1",
            "tensors": {},
        }
        dg_tens_meta = tens_data["tensors"]
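        # Each tensor gets a meta entry along these lines (illustrative values
        # only; the exact fields depend on the op and generator type):
        #   {"generator": "PSEUDO_RANDOM", "data_type": "FP32",
        #    "shape": [1, 8, 8], "input_pos": 0, "input_type": "VARIABLE",
        #    "op": "ADD",
        #    "pseudo_random_info": {"rng_seed": 42, "range": ["-128.0", "128.0"]}}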
        tens_ser_list = []
        for idx, shape in enumerate(shapeList):

            tens_meta = {}
            tens_meta["generator"] = gtu.DataGenType(dg_type).name
            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()

            if idx < pCount:
                tens_meta["input_type"] = "VARIABLE"
            else:
                tens_meta["input_type"] = "CONSTANT"

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                # TODO - generate seed for this generator based on test
                info["rng_seed"] = 42
                if "data_range" in argsDict:
                    data_range = argsDict["data_range"]
                else:
                    data_range = testGen.getDTypeRange(
                        dtypeList[idx], high_inclusive=True
                    )
                info["range"] = [str(v) for v in data_range]
                tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = int(argsDict["ks"])
                if "acc_type" in argsDict:
                    # Convert type number into JSON name
                    info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
                        "json"
                    ]
                if "kernel" in argsDict:
                    info["kernel"] = [int(k) for k in argsDict["kernel"]]
                if "axis" in argsDict:
                    info["axis"] = int(argsDict["axis"])
                tens_meta["dot_product_info"] = info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"

            # Using the finished generate config meta data - generate the data if
            # needed and assign a tensor name from the serializer

            # Need to generate data when not lazy or for the bias tensor as we need
            # to work out if the bias data is non-zero for compliance
            if not testGen.args.lazy_data_gen or (
                idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
            ):
                # Give this tensor a temporary name until we get one from the serializer
                temp_name = f"placeholder_{idx}"
                dg_tens_meta[temp_name] = tens_meta
                # Create data now using the temporary name to access meta details
                data = testGen.dgl.get_tensor_data(temp_name, tens_data)
                # Remove the item as we will give it the correct name later
                del dg_tens_meta[temp_name]

                if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
                    # The KS value used by compliance verification is altered when the
                    # bias data is non-zero
                    if max(abs(data)) > 0.0:
                        argsDict["ksb"] = argsDict["ks"] + 1

            if testGen.args.lazy_data_gen:
                data = None

            if tens_meta["input_type"] == "VARIABLE":
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
            else:
                tens = testGen.ser.addConst(shape, dtypeList[idx], data)

            tens_ser_list.append(tens)
            # Add the meta data to the list using the serializer tensor name
            dg_tens_meta[tens.name] = tens_meta

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)

    @staticmethod
    def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32 and error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.NEGATE must have 1 placeholder, 0 consts"
            # Must create tensors with values within accumulator (int32) negatable
            # range
            max_val = (1 << 31) - 1
            min_val = -max_val
            arr = np.int32(
                testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Make sure the operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            placeholders = []
            add = op["op"] == Op.ADD
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)
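            # e.g. a = b = 2**30: res = 2**31 exceeds max_i32 by 1, so
            # sat_max_arr is 1 there and b is clipped down by 1 below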

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgCondIfWhileLoop(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        if dtypeList[0] in (
            DType.INT32,
            DType.INT16,
            DType.INT8,
        ):
            # Limit input tensors with cond_if_binary or while_loop to stop
            # saturation of add/sub ops with int32 and keep all logical shift
            # values between 0 and 31 for int16 or int8
            pCount, cCount = op["operands"]
            pRemain = pCount
            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[0] == DType.INT32:
                    arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
                else:
                    arr = np.int32(
                        testGen.rng.integers(low=0, high=32, size=shapeList[idx])
                    )
                if pRemain > 0:
                    placeholders.append(
                        testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
                    )
                    pRemain -= 1
                else:
                    placeholders.append(
                        testGen.ser.addConst(shape, dtypeList[idx], arr)
                    )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgArithmeticRightShift(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        pCount, cCount = op["operands"]
        # Force value of operand[1] to be within [0, num_bits]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

        placeholders = []
        for idx, shape in enumerate(shapeList[:]):
            if idx == 1:
                if dtypeList[idx] == DType.INT8:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                elif dtypeList[idx] == DType.INT16:
                    arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
                elif dtypeList[idx] == DType.INT32:
                    arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
                elif error_name == ErrorIf.WrongInputType:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                else:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
            else:
                arr = testGen.getRandTensor(shape, dtypeList[idx])
            placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))

        return placeholders

    @staticmethod
    def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        # Set datatype of condition tensor to boolean
        dtypeList[0] = DType.BOOL

        return TosaTensorValuesGen.tvgDefault(
            testGen, op, dtypeList, shapeList, testArgs, error_name
        )

    @staticmethod
    def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
            while True:
                dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
                    continue

                break

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    # Set the data range to the square root of the largest value
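    # so that the product of two in-range values can never overflow to infinity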
    TVG_FLOAT_HIGH_VALUE_MUL = {
        DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }

    @staticmethod
    def tvgMul(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        if error_name is not None or dtypeList[0] in (
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ):
            # ERROR_IF or floating point test
            if dtypeList[0] in TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL:
                data_range = testGen.getDTypeRange(dtypeList[0], high_inclusive=True)
                high_val = TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL[dtypeList[0]]
                # Set the values to something that won't produce infinity whilst
                # respecting the default ranges if less than the high value
                argsDict["data_range"] = [
                    max(-high_val, data_range[0]),
                    min(high_val, data_range[1]),
                ]
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
        else:
            # Integer test
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            tens_ser_list = []

            # Make sure the multiply result is within int32 range
            shift = argsDict["shift"]
            if dtypeList[0] == DType.INT8:
                num_bits = 8
            elif dtypeList[0] == DType.INT16:
                num_bits = 16
            elif dtypeList[0] == DType.INT32:
                num_bits = 32
            elif error_name == ErrorIf.WrongInputType:
                num_bits = 8
            else:
                raise Exception("OpMul: invalid input dtype")

            for idx, shape in enumerate(shapeList[:]):
                low = -(2 ** (num_bits - 1))
                high = (2 ** (num_bits - 1)) - 1

                a_arr = np.int32(
                    testGen.rng.integers(low=low, high=high, size=shapeList[0])
                )
                b_arr = np.int32(
                    testGen.rng.integers(low=low, high=high, size=shapeList[1])
                )

                i = 0
                while True:

                    a_arr_64 = a_arr.astype(np.int64)
                    b_arr_64 = b_arr.astype(np.int64)

                    if shift > 0:
                        rounding = 1 << (shift - 1)
                        result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                    else:
                        result_arr = a_arr_64 * b_arr_64

                    if (result_arr > -(2**31)).all() and (
                        result_arr <= ((2**31) - 1)
                    ).all():
                        break

                    i = i + 1
                    a_arr = a_arr // 2
                    b_arr = b_arr // 2

            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

    @staticmethod
    def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        count = len(shapeList) - testGen.args.num_const_inputs_concat
        if count < 1:
            count = 1
        if testGen.args.num_const_inputs_concat == 0:
            count = len(shapeList)

        # Ensure axis is an int
        testArgs[0] = int(testArgs[0])

        shapeList = TosaTensorGen.tgConcatConstInput(
            testGen, shapeList, testArgs[0], error_name
        )

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
        )
        tens.extend(testGen.buildConstTensors(shapeList[count:], dtypeList[count:]))

        return tens

    @staticmethod
    def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        pCount, cCount = op["operands"]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
        values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
        shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
        placeholders = []
        placeholders.append(
            testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
        )
        placeholders.append(
            testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
        )

        return placeholders

    @staticmethod
    def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        testGen.rng.integers(
                            0, max(shapeList[0][axis], shapeList[1][axis])
                        )
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.REDUCE_SUM must have 1 placeholder, 0 consts"
            # Limit values so that the sum cannot exceed the range of an int32 during
            # summation of any axis
            range_val = int((1 << 31) / max(shapeList[0]))
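            # e.g. for shape [4, 100], range_val = 2**31 // 100 = 21474836, so
            # even a sum along the axis of size 100 stays within int32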
            values_arr = np.int32(
                testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )


class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for
    operators that take attributes or other parameters.

    The return value is a list of (descriptive_name, [arglist]) tuples where
    the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function.
    """

    def __init__(self):
        pass

    @staticmethod
    def _add_data_generators(testGen, opName, dtype, arg_list, error_name):
        """Add extra tests for each type of data generator for this op."""
        if (
            error_name is None
            and "data_gen" in testGen.TOSA_OP_LIST[opName]
            and gtu.dtypeIsSupportedByCompliance(dtype)
        ):
            if dtype in [DType.FP16, DType.FP32, DType.BF16]:
                dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
            else:
                dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
        else:
            # Error test or no data generator types listed - assume random
            dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)

        # Expand arg list with other data generator types
        new_arg_list = []
        for dg_type in dataGenTypesList:
            for arg_str, args_dict in arg_list:
                args_dict["dg_type"] = dg_type
                if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                    # Default test
                    new_arg_list.append((arg_str, args_dict))

                elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                    # Extra tests for each dot product test set
                    dot_products = args_dict["dot_products"]
                    if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
                        print(
                            f"Skipping {opName} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
                        )
                        continue
                    # KS is required by all dot product generators
                    assert "ks" in args_dict

                    for s in testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS:
                        new_arg_str = f"{arg_str}_s{s}"
                        new_args_dict = args_dict.copy()
                        new_args_dict["s"] = s
                        new_arg_list.append((new_arg_str, new_args_dict))

        return new_arg_list

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype, error_name=None):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype, error_name=None):
        """Build the axis argument for operators that take a single axis"""
        axes = []
        shape = shapeList[0]

        if error_name == ErrorIf.AxisSmallerZero:
            small_axis = testGen.rng.integers(-5, 0)
            axes.append(("axis{}".format(small_axis), [small_axis]))
        elif error_name == ErrorIf.AxisLargerRank:
            large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
            axes.append(("axis{}".format(large_axis), [large_axis]))
        else:
            for a in range(0, len(shape)):
                axes.append(("axis{}".format(a), [a]))

        return axes

    @staticmethod
    def _calculate_sparsity(num_tests, sparsity_factor):
        sparsity = num_tests // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
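        # e.g. 2000 candidate tests with sparsity_factor 120 gives sparsity 17,
        # so roughly one in every 17 parameter combinations is kept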
        return sparsity

    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        # Shape: Batches, (Depth), Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: (OFM channels), (KD), KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Op type checks
        conv3d = opName.startswith("conv3d")
        depthwise = opName.startswith("depthwise")

        # Check the rank
        rank = 5 if conv3d else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits channels
        k_rank = rank - 2
        k_pos = 0 if depthwise else 1
        k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
        # compliance size - KS
        k_size = gtu.product(k_shape)
        if not depthwise:
            k_size *= ifm_shape[-1]

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            if error_name == ErrorIf.PadSmallerZero:
                p_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                # Stride must be greater than 1 to force non-integer error
                startStride = (
                    1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
                )
                s_vals = [
                    x for x in range(startStride, testGen.args.max_conv_stride + 1)
                ]
            strides = {x for x in itertools.product(*([s_vals] * k_rank))}
            if error_name == ErrorIf.DilationSmallerOne:
                d_vals = [testGen.rng.choice(range(-5, 1))]
            else:
                d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
            dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[0, bigPadding]] * (k_rank * 2))
                            )
                        }
                    )
                bigStride = 8
                strides.update(
                    {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
                )
                bigDilation = 7
                dilations.update(
                    {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
                )
            max_dim_size = None

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(paddings) * len(strides) * len(dilations), sparsity_factor
            )
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            dilation_shape = [1] * k_rank
            pad_shape = [0] * k_rank * 2
            if conv3d:
                # Small stride apart from for big kernel (see below) to keep
                # tensor size/calculation small
                stride_shape = [1] * k_rank
                for idx in range(k_rank):
                    pad_offset = idx * 2
                    if k_shape[idx] == bigKernel:
                        # Padding shape needs to account for tensor shape
                        pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                        pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                        # Big stride to reduce output size
                        stride_shape[idx] = bigKernel
                    else:
                        # Account for kernel size
                        pad_shape[pad_offset] = k_shape[idx] - 1
            else:
                # Always have a large stride with extra padding and dilation to keep
                # tensor calculation reasonable
                stride_shape = [bigKernel] * k_rank
                for idx in range(k_rank):
                    # Dilation shape must account for kernel size
                    dilation_shape[idx] = bigKernel // k_shape[idx]
                    # Padding shape needs to accommodate tensor/kernel & dilation
                    pad_offset = idx * 2
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

            strides = {tuple(stride_shape)}
            dilations = {tuple(dilation_shape)}
            paddings = {tuple(pad_shape)}
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a positive
                        # sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
                        and (
                            k_rank < 3
                            or (
                                (ifm_shape[3] - 1 + p[4] + p[5])
                                > d[2] * (k_shape[2] - 1)
                            )
                        )
                    ):
                        remainders = []
                        outputs = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            partial = (
                                ifm_shape[index + 1]
                                - 1
                                + p[pad_offset]
                                + p[pad_offset + 1]
                                - (k_shape[index] - 1) * d[index]
                            )
                            remainders.append(partial % s[index])
                            outputs.append((partial // s[index]) + 1)
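                            # e.g. ifm=15, pad=0+0, kernel=3, dilation=1, stride=2:
                            # partial = 15 - 1 - 2 = 12, output = 12 // 2 + 1 = 7,
                            # remainder 0 (an integer-exact output size)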
1417
                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            if (
                                max_dim_size is not None
                                and max(outputs) >= max_dim_size
                            ):
                                # Test will consume too much memory - skip it
                                continue

                            # Compliance - number of dot product calculations
                            if depthwise:
                                # TODO - add support
                                dots = 0
                            else:
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, filter_shape[0])
                                )
                            args_dict = {
                                "acc_type": accum_dtype,
                                "stride": s,
                                "pad": p,
                                "dilation": d,
                                "kernel": k_shape,
                                "ks": k_size,
                                "dot_products": dots,
                            }

                            # Support for larger values than 9 needs different delimiter
                            delim = "" if max(s + p + d) <= 9 else "x"
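                            # e.g. (hypothetical values) stride (2, 1) formats
                            # as "st21", but stride (10, 1) would be ambiguous,
                            # so it becomes "st10x1". Note s + p + d is tuple
                            # concatenation here, not element-wise addition.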
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        delim.join([str(x) for x in s]),
                                        delim.join([str(x) for x in p]),
                                        delim.join([str(x) for x in d]),
                                    ),
                                    args_dict,
                                )
                            )
                    n += 1

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtypes[0],
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list

    @staticmethod
    def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):

        assert isinstance(dtypes, list) or isinstance(
            dtypes, tuple
        ), f"{dtypes} unexpected"
        input_dtype = dtypes[0]

        if error_name == ErrorIf.WrongOutputType:
            accum_dtype = gtu.get_wrong_output_type(opName, testGen.rng, input_dtype)
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            accum_dtype = DType.INT32
        else:
            accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])]

    @staticmethod
    def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
        # Get valid accumulate type(s)
        if dtype == DType.INT8:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.INT16:
            accum_dtypes = [DType.INT48]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FP32]
        elif dtype == DType.BF16:
            accum_dtypes = [DType.FP32]
        elif dtype == DType.FP32:
            accum_dtypes = [DType.FP32]
        elif error_name is None:
            assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"

        if error_name == ErrorIf.WrongOutputType:
            # Get incorrect output dtype for ErrorIf case
            accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            accum_dtypes = [DType.INT32]

        # Set up compliance info
        args_dict = {
            "ks": int(shapeList[0][2]),  # Set KS = C, from input A (N,H,C)
            # Set dot_products = N*H*W
            "dot_products": gtu.product(
                (shapeList[0][0], shapeList[0][1], shapeList[1][2])
            ),
        }
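        # e.g. (hypothetical shapes) A = (2, 3, 4) matmul B = (2, 4, 5):
        # each output element is a dot product over C = 4 terms, so ks = 4,
        # and there are 2 * 3 * 5 = 30 output elements, so dot_products = 30.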

        # Create arg tuple of string and dict
        arg_list = []
        for a in accum_dtypes:
            d = args_dict.copy()
            d["acc_type"] = a
            arg_list.append((f"acc{testGen.typeStr(a)}", d))

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Must be rank 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == 4
            assert len(filter_shape) == 4

        k_shape = tuple(filter_shape[1:3])

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
            if error_name == ErrorIf.PadLargerEqualKernel:
                max_filter_size = -max(k_shape[0], k_shape[1])
                p_vals = [
                    testGen.rng.choice(range(max_filter_size - 10, max_filter_size))
                ]
            else:
                p_vals = [
                    x
                    for x in range(
                        smallest_padding_size, testGen.args.max_conv_padding + 1
                    )
                ]
            paddings = {x for x in itertools.product(*([p_vals] * 4))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as the stride is used as a divisor when
                # deriving the output shape
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
            strides = {x for x in itertools.product(*([s_vals] * 2))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[smallest_padding_size, bigPadding]] * 4)
                            )
                        }
                    )
                bigStride = 8
                strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 10
            sparsity = len(paddings) * len(strides) // sparsity_factor + 1
            # If there are only a small number of tests, just select them all
            if sparsity < 13:
                sparsity = 1
            # To get a variety of parameter combinations sparsity should not be a
            # multiple of 2, 3 or 5
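            # e.g. (hypothetical sizes) 256 paddings x 4 strides with
            # sparsity_factor 10 gives sparsity = 1024 // 10 + 1 = 103, which
            # is not a multiple of 2, 3 or 5, so taking every 103rd combination
            # samples roughly 10 well-spread tests out of the 1024 available.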
            while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
                sparsity += 1
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            pad_shape = [0] * (len(k_shape) * 2)
            stride_shape = [1] * len(k_shape)
            # The point at which input dimension combined with the stride will
            # create large output sizes!
            LARGE_SIZE = 2
            for idx in range(len(k_shape)):
                pad_offset = idx * 2
                if k_shape[idx] == bigKernel:
                    # Set large stride
                    stride_shape[idx] = bigKernel
                    # Use negative output padding to reduce shape size
                    pad_shape[pad_offset] = -(bigPadding - 1)
                    if ifm_shape[idx + 1] > LARGE_SIZE:
                        pad_shape[pad_offset + 1] = -(bigPadding - 1)
                else:
                    # The other dimension should be the bigKernel
                    alt_idx = 1 - idx
                    if (
                        k_shape[alt_idx] == bigKernel
                        and ifm_shape[alt_idx + 1] < LARGE_SIZE
                    ):
                        # As the input is small, the large stride won't
                        # affect the output so we can add some padding
                        pad_shape[pad_offset + 1] = bigPadding

            strides = {tuple(stride_shape)}
            paddings = {tuple(pad_shape)}

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                if n % sparsity == 0:
                    # Determine the output shape
                    oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
                    ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]
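                    # e.g. (hypothetical values) ifm_shape[1] = 8, stride 2,
                    # pad (0, 0) and kernel 3 give
                    #   oh = (8 - 1) * 2 + 0 + 0 + 3 = 17
                    # i.e. transpose conv grows the input, which is why the 8k
                    # branch above uses negative output padding to keep the
                    # result small.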

                    # Support for larger values than 9 needs different delimiter
                    delim = "" if max(s + p) <= 9 else "x"
                    arg_list.append(
                        (
                            "acc{}_st{}_pad{}_os{}".format(
                                testGen.typeStr(accum_dtype),
                                delim.join([str(x) for x in s]),
                                delim.join([str(x) for x in p]),
                                "x".join([str(x) for x in os]),
                            ),
                            [accum_dtype, s, p, os],
                        )
                    )
                n += 1

        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))
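        # The combination count grows as len(pad_values) ** (2 * rank), e.g.
        # (hypothetical) 2 pad values at rank 4 give 2**8 = 256 combinations,
        # hence the sparsity handling below for the larger ranks.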

        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            return []

        list_shape_pad_values = list(shape_pad_values)
        # If we are producing tests for rank 6 or greater use sparsity
        if len(list_shape_pad_values) > 1024:
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(list_shape_pad_values), sparsity_factor
            )
        else:
            sparsity = 1

        # Build arg list
        arg_list = []
        for n, paddings in enumerate(list_shape_pad_values):
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                    if all([p > -1 for p in paddings[i]]):
                        args_valid = False
            if args_valid and n % sparsity == 0:
                name = "pad"
                for r in range(rank):
                    before, after = paddings[r]
                    name = f"{name}{before}{after}"
                args_dict = {
                    "pad": np.array(paddings),
                    "pad_const_int": pad_const_int,
                    "pad_const_fp": pad_const_fp,
                }
                arg_list.append((name, args_dict))

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list

    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        shape = shapeList[0]
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        test_level8k = testGen.args.level8k and error_name is None

        startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
        startKernel = 2
        startPad = 0
        if not test_level8k:
            # Generate comprehensive argument lists
            p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * 4))}
            # Stride must be greater than 1 to force non-integer error
            s_vals = [
                x for x in range(startStride, testGen.args.max_pooling_stride + 1)
            ]
            strides = {x for x in itertools.product(*([s_vals] * 2))}
            k_vals = [
                x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
            ]
            kernels = {x for x in itertools.product(*([k_vals] * 2))}
            max_dim_size = None
        else:
            # Only test 8k levels
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            strides = {(1, bigStride), (bigStride, 4)}
            kernels = {(1, bigKernel), (bigKernel, 3)}
            paddings = set()
            for s in sorted(list(strides)):
                for k in sorted(list(kernels)):
                    padding = []
                    for idx in range(len(k)):
                        total_padding = s[idx] - shape[idx + 1] + k[idx]
                        while total_padding < 0:
                            # Must meet: shape + padding > kernel
                            total_padding += s[idx]
                        if total_padding < k[idx]:
                            padding.extend([0, total_padding])
                        else:
                            # Note this may produce padding >= k[idx] which is not
                            # allowed - but will be ignored in the creation loop below
                            padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
                    paddings.add(tuple(padding))
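                    # e.g. (hypothetical values) s = 2, input dim 10, k = 3:
                    # total_padding = 2 - 10 + 3 = -5 is stepped up by s to 1,
                    # and 1 < 3, so padding (0, 1) is used for that dimension.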
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

        if opName == "max_pool2d":
            accum_dtypes = [None]  # max_pool has no accumulate dtype
        elif dtype == DType.INT8 or dtype == DType.INT16:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FP32]
        elif dtype == DType.BF16 or dtype == DType.FP32:
            accum_dtypes = [DType.FP32]
        elif error_name is None:
            assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
        else:
            # Set to something for the ErrorIf case which has
            # incorrect input data-type
            accum_dtypes = [DType.INT32]

        if not test_level8k:
            if testGen.args.oversize:
                # add some oversize argument values
                bigStride = 7
                bigKernel = 9
                strides.update(
                    {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
                )
                kernels.update(
                    {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
                )
                if max(shape) < 64:
                    # padding must be less than the kernel size
                    bigPadding = bigKernel - 1
                    paddings.update(
                        {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
                    )

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 500
            sparsity = (
                len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
            )
        else:
            # We have already limited test output combinations for 8k tests
            sparsity = 1

        arg_str = (
            "acc{}_st{}_kern{}_pad{}"
            if accum_dtypes[0] is not None
            else "st{}_kern{}_pad{}"
        )

        def get_arg_list_element(accum, stride, pad, kern, dot_products=0):
            # Return tuple containing the formatted argument string and
            # the corresponding argument values in a dictionary

            # Support for larger values than 9 needs different delimiter
            delim = "" if max(stride + kern + pad) <= 9 else "x"
            arg_str_elems = [
                delim.join([str(x) for x in stride]),
                delim.join([str(x) for x in kern]),
                delim.join([str(x) for x in pad]),
            ]
            args_dict = {
                "stride": stride,
                "pad": pad,
                "kernel": kern,
                "dot_products": dot_products,  # Ignored for error tests
                "ks": gtu.product(kern),  # avg_pool2d: KS = KX*KY
            }

            if accum is not None:
                arg_str_elems.insert(0, testGen.typeStr(accum))
                args_dict["acc_type"] = accum
            return (arg_str.format(*arg_str_elems), args_dict)

        n = 0
        for a in accum_dtypes:
            for s in sorted(list(strides)):
                for p in sorted(list(paddings)):
                    for k in sorted(list(kernels)):
                        if error_name in [
                            ErrorIf.StrideSmallerOne,
                            ErrorIf.KernelSmallerOne,
                            ErrorIf.PadSmallerZero,
                            ErrorIf.PadLargerEqualKernel,
                        ]:
                            sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
                                testGen, error_name, s, p, k
                            )
                            if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                                arg_list.append(
                                    get_arg_list_element(a, sNew, pNew, kNew)
                                )
                        elif (
                            n % sparsity == 0
                            # padding must not exceed the kernel size
                            and p[0] < k[0]
                            and p[1] < k[0]
                            and p[2] < k[1]
                            and p[3] < k[1]
                            # the padded shape must exceed the kernel size
                            and (shape[1] + p[0] + p[1]) > k[0]
                            and (shape[2] + p[2] + p[3]) > k[1]
                        ):
                            partial_h = shape[1] + p[0] + p[1] - k[0]
                            partial_w = shape[2] + p[2] + p[3] - k[1]
                            remainder_h = partial_h % s[0]
                            remainder_w = partial_w % s[1]
                            output_h = partial_h // s[0] + 1
                            output_w = partial_w // s[1] + 1
                            # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
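                            # e.g. (hypothetical values) shape[1] = 7,
                            # p = (1, 1, ...), k = (3, ...), s = (2, ...):
                            #   partial_h = 7 + 1 + 1 - 3 = 6
                            #   remainder_h = 6 % 2 = 0 and
                            #   output_h = 6 // 2 + 1 = 4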
                            if (
                                # the parameters must produce integer exact output
                                error_name != ErrorIf.PoolingOutputShapeNonInteger
                                and remainder_h == 0
                                and remainder_w == 0
                            ) or (
                                error_name == ErrorIf.PoolingOutputShapeNonInteger
                                and (remainder_h != 0 or remainder_w != 0)
                            ):
                                if (
                                    max_dim_size is not None
                                    and max(output_h, output_w) > max_dim_size
                                ):
                                    # Test will consume too much memory - skip it
                                    continue
                                # Dot products = N*OH*OW*C
                                dp = gtu.product(
                                    (shape[0], output_h, output_w, shape[3])
                                )
                                arg_list.append(get_arg_list_element(a, s, p, k, dp))
                        n += 1

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list

    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        if error_name == ErrorIf.WrongOutputType:
            dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
        elif inDtype == DType.INT8:
            dtypeList = [
                DType.BOOL,
                DType.INT16,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT16:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT32:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT16,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FP16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
        elif inDtype == DType.BF16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
        elif inDtype == DType.FP32:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output type for incorrect input type
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(testGen.typeStr(dtype)), [dtype]))

        return arg_list

    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition. Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [outDtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

    @staticmethod
    def agMul(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        if dtype is DType.INT32:
            for p in range(testGen.args.num_rand_permutations):

                shift = testGen.randInt(0, 32)
                arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
        else:
            arg_list.append(("perm0_shift0", {"shift": 0}))

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    @staticmethod
    def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        arg_list.append(("inverseTrue", [True]))
        arg_list.append(("inverseFalse", [False]))

        return arg_list

    # Helper function for reshape. Gets some factors of a larger number.
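    # Note: only factors up to sqrt(val) are returned,
    # e.g. getFactors(24) -> [1, 2, 3, 4].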
    @staticmethod
    def getFactors(val, start=1):
        factors = []

        for i in range(start, int(np.sqrt(val)) + 1):
            if (val % i) == 0:
                factors.append(i)

        return factors

    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            # Rank from 1 to TOSA_TENSOR_MAX_RANK
            newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks the while loop if it runs for too long
            escape_counter = 0
            while found:
                newShape = []
                new_shape_inferred = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                inferred_dim = testGen.rng.integers(1, newRank + 1)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    if i == inferred_dim:
                        new_shape_inferred.append(-1)
                    else:
                        new_shape_inferred.append(shuffledFactors[0])
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)
                if inferred_dim == newRank:
                    new_shape_inferred.append(-1)
                else:
                    new_shape_inferred.append(remainingElements)

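                # e.g. (hypothetical values) totalElements = 24 with
                # newShape = [2, 3, 4] and inferred_dim = 2 gives
                # new_shape_inferred = [2, -1, 4], where -1 marks the
                # dimension whose size is to be inferred from the input.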
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002195 # Check for duplicates
2196 found = False
2197 for name, other_shape in arg_list:
2198 if other_shape[0] == newShape:
2199 found = True
2200 break
2201
2202 escape_counter += 1
2203 if escape_counter >= 100:
2204 break
2205
2206 if not found:
Jerry Ge264f7fa2023-04-21 22:49:57 +00002207 if error_name in [
2208 ErrorIf.ReshapeOutputSizeNonInteger,
2209 ErrorIf.ReshapeOutputSizeMultiInference,
2210 ]:
2211 if newRank < 2:
2212 # Need at least two dimensions
2213 continue
2214 # NOTE: Change inferred_dim starting offset from 1 to 0
2215 inferred_dim -= 1
2216 extra_dim = inferred_dim + testGen.rng.integers(1, newRank)
2217 extra_dim = extra_dim % newRank
2218 assert extra_dim != inferred_dim
2219 if error_name == ErrorIf.ReshapeOutputSizeNonInteger:
2220 elements = 1
2221 for i, dim_value in enumerate(new_shape_inferred):
2222 if i != inferred_dim and i != extra_dim:
2223 elements *= dim_value
2224 dim_value = new_shape_inferred[extra_dim]
2225 while totalElements % (elements * dim_value) == 0:
2226 dim_value += 1
2227 new_shape_inferred[extra_dim] = dim_value
2228 else:
2229 assert error_name == ErrorIf.ReshapeOutputSizeMultiInference
2230 new_shape_inferred[extra_dim] = -1
2231 else:
2232 arg_list.append(
2233 ("perm{}_rank{}_outdefined".format(p, newRank), [newShape])
2234 )
2235 if error_name != ErrorIf.TensorSizeInputOutputMismatch:
2236 arg_list.append(
2237 (
2238 "perm{}_rank{}_outinferred".format(p, newRank),
2239 [new_shape_inferred],
2240 )
2241 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002242
2243 return arg_list
2244
2245 @staticmethod
2246 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
2247 arg_list = []
2248
2249 ifm_shape = shapeList[0]
2250
2251 if error_name == ErrorIf.IndexOutsideBounds:
2252 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
2253 incorrect_small_index = range(-len(ifm_shape), 0)
2254 permutations = [p for p in itertools.permutations(incorrect_large_index)]
2255 permutations.extend(
2256 [p for p in itertools.permutations(incorrect_small_index)]
2257 )
2258 elif error_name == ErrorIf.IndexUsedTwice:
2259 # Create list with a duplicated index
2260 perm_range = list(range(len(ifm_shape)))
2261 index_choice = testGen.rng.choice(range(len(perm_range)))
2262 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
2263 permutations = [p for p in itertools.permutations(perm_range)]
2264
2265 else:
2266 # Get all permutations
2267 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
2268
2269 # Limit to possible permutations from shape dimension or argument setting
2270 limit = min(len(permutations), testGen.args.num_rand_permutations)
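        # e.g. (hypothetical) a rank 3 input has 3! = 6 possible permutations;
        # with num_rand_permutations = 4 only 4 of those 6 are emitted.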

        # Get random permutation generator that uses all permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create list of required amount of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list

    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            start = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    start.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - start[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    start.append(0)
                    size.append(1)

            if valid:
                # If ERROR_IF test required then incorrect start, size will be returned
                start, size = TosaErrorIfArgGen.eiSliceErrorIf(
                    testGen, error_name, ifm_shape, start, size
                )
                arg_list.append(("perm{}".format(p), [start, size]))
        return arg_list

    @staticmethod
    def agTile(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random, but small multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                if ifm_shape[i] > 1000:
                    # Multiple of 1 if ifm_shape dimension is large to reduce
                    # tensor size
                    multiples.append(1)
                elif max(ifm_shape) > 1000:
                    multiples.append(2)
                else:
                    multiples.append(testGen.randInt(1, 4))
            arg_list.append(("perm{}".format(p), [multiples]))

        return arg_list

    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []
        ifm_shape = shapeList[0]

        def get_aspect_ratio_resize_params():
            common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
            aspect_ratio = testGen.rng.choice(common_aspect_ratios)
            invert = testGen.rng.choice((False, True))
            letterbox = testGen.rng.choice((False, True))

            scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
            scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
            scale_y_d = scale_x_d = 1
            offset_x = offset_y = 0

            if letterbox:
                max_border = scale_y_n
                border_y = testGen.randInt(low=0, high=max_border)
                border_x = 0
            else:
                # Pillarboxing
                border_y = 0
                max_border = scale_x_n
                border_x = testGen.randInt(low=0, high=max_border)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)

            return scale, offset, border

        def get_upscale_downscale_params():
            valid_params = False
            while not valid_params:
                upscale = testGen.rng.choice((False, True))

                # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
                origin_sampling = testGen.rng.choice((False, True))

                if upscale:
                    shift = testGen.randInt(low=1, high=4)
                    scale_x_d = scale_y_d = 1
                    scale_x_n = scale_y_n = (
                        1 << shift if origin_sampling else 2 << shift
                    )
                    border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                    offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
                else:
                    scale_x_n = 1
                    scale_y_n = 1

                    # Return list of valid scale_*_d values (max value 4) given input dim shape
                    def get_valid_denom(ifm_dim):
                        return [x for x in range(1, 5) if ifm_dim % x == 1]
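                    # A denominator d is valid when (ifm_dim - 1) % d == 0, so
                    # the downscaled output size is an integer; d = 1 is
                    # covered by the fallback below, e.g. (hypothetical)
                    # ifm_dim = 9 gives [2, 4] as 9 % 2 == 1 and 9 % 4 == 1.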

                    # Generate list of valid downscale values and choose one randomly
                    valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                    valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                    if not valid_scale_y_ds and not valid_scale_x_ds:
                        # Bad parameters, skip
                        continue

                    if not valid_scale_y_ds:
                        scale_y_d = 1
                    else:
                        scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                    if not valid_scale_x_ds:
                        scale_x_d = 1
                    else:
                        scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                border_x = border_y = 0
                offset_y = testGen.randInt(0, 16 * scale_y_n)
                offset_x = testGen.randInt(0, 16 * scale_x_n)
                valid_params = True

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_rand_params():
            def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
                scale = scale_n / scale_d
                if scale > max_scale:
                    factor = scale / max_scale
                    new_scale_d = math.ceil(scale_d * factor)
                    assert scale_n / new_scale_d <= max_scale
                    scale_d = new_scale_d
                return scale_d

            # Scale
            scale_y_n = testGen.randInt(low=1, high=(1 << 11))
            scale_x_n = testGen.randInt(low=1, high=(1 << 11))

            scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

            scale_y_d = fix_scale_to_max_scale(
                scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_x_d = fix_scale_to_max_scale(
                scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )

            # Offsets and border within the scale
            offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
            border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_level_8k_params():
            # Create 64x scale - 64/1 to 2048/32
            scale_d = testGen.randInt(
                low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
            # Create half to fifth scaling
            scale_d_alt = testGen.randInt(low=2, high=6)
            scale_n_alt = 1
            switch = testGen.rng.choice((False, True))
            if switch:
                scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
            else:
                scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)

            offset_y = testGen.rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
            offset_x = testGen.rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
            offset = (offset_y, offset_x)
            border_y = testGen.rng.choice((-16 * scale[0], 0, scale[0] - 1))
            border_x = testGen.rng.choice((-16 * scale[2], 0, scale[2] - 1))
            border = (border_y, border_x)
            return scale, offset, border

        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.BF16:
                outputDTypeList = [DType.BF16]
            elif dtype == DType.FP32:
                outputDTypeList = [DType.FP32]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                perm = 0
                while perm < testGen.args.num_rand_permutations:
                    # Random choice of type of params we are testing
                    if not testGen.args.level8k:
                        _rnd_param_fn = testGen.rng.choice(
                            (
                                get_rand_params,
                                get_upscale_downscale_params,
                                get_aspect_ratio_resize_params,
                            )
                        )
                        scale, offset, border = _rnd_param_fn()
                    else:
                        scale, offset, border = get_level_8k_params()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        # Look for non-integer test
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            if perm > 0:
                                perm += 1
                            continue
                    else:
                        # Alter the scaling factors to make the output integer
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1
                        # Make sure we are still within max scaling
                        if (
                            scale_y_n / scale_y_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
                            scale_x_n / scale_x_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
                            # Skip the test as it is using too large a scaling factor
                            if perm > 0:
                                perm += 1
                            continue

                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1
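                    # e.g. (hypothetical values) ifm_shape[1] = 8 with scale
                    # 2/1 (scale_y_n = 2, scale_y_d = 1), offset 0, border 0:
                    #   partial_output_y = (8 - 1) * 2 - 0 + 0 = 14
                    #   output_y = 14 // 1 + 1 = 15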

                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        if not testGen.args.level8k or perm > 0:
                            perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= gtu.MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= gtu.MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]

                    # Common for all data types
                    if error_name is not None:
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        [
                            mode,
                            scale,
                            offset,
                            border,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)
        return arg_list

    @staticmethod
    def agTable(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        if dtype == DType.INT8:
            table = np.int32(
                testGen.rng.integers(low=-128, high=128, size=[256])
            ).tolist()
        else:  # INT16
            table = np.int32(
                testGen.rng.integers(low=-32768, high=32768, size=[513])
            ).tolist()
            # Make sure all slopes are within REQUIRE min/max 16-bit int
            for idx in range(len(table) - 1):
                slope = table[idx + 1] - table[idx]
                # Alter the next table entry to force the slope to be ok
                if slope > 32767:
                    table[idx + 1] -= slope - 32767
                if slope < -32768:
                    table[idx + 1] -= slope + 32768
                slope = table[idx + 1] - table[idx]
                assert slope <= 32767 and slope >= -32768
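                # e.g. (hypothetical values) table[idx] = 30000 and
                # table[idx + 1] = -30000 give slope = -60000; the second
                # correction adds -(-60000 + 32768) = 27232, moving
                # table[idx + 1] to -2768 so the new slope is exactly -32768.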
        arg_list.append(
            (
                "",
                [table],
            )
        )
        return arg_list

    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
        # While loop: 0 iterations, 1, more than 1
        arg_list = []

        for iter in [0, 1, 4]:
            arg_list.append(("iter{}".format(iter), [iter]))

        return arg_list