blob: 193da73d0129c1872aed3019f4d5097eada66f2e [file] [log] [blame]
Luke Hutton261b7b62023-01-10 14:50:31 +00001# Copyright (c) 2021-2023, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
4import math
James Ward8b390432022-08-12 20:48:56 +01005import warnings
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01006
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
16# DTypeNames, DType, Op and ResizeMode are convenience variables to the
17# flatc-generated types that should be enums, but aren't
18
19
20class TosaQuantGen:
21 """QuantizedInfo random generator helper functions.
22
23 Specify with 'qgen': in the operator defintion.
24 """
25
26 def __init__(self):
27 pass
28
29 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +000030 def getZeroPoint(testGen, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010031
32 if dtype == DType.INT8:
Jeremy Johnson00423432022-09-12 17:27:37 +010033 if testGen.args.zeropoint is not None:
34 return min(127, max(-128, testGen.args.zeropoint))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010035 return testGen.randInt(-128, 128)
36 elif dtype == DType.UINT8:
Jeremy Johnson00423432022-09-12 17:27:37 +010037 if testGen.args.zeropoint is not None:
38 return min(255, max(0, testGen.args.zeropoint))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010039 return testGen.randInt(0, 256)
40 elif error_name in [
41 ErrorIf.InputZeroPointNotZero,
42 ErrorIf.WeightZeroPointNotZero,
43 ErrorIf.OutputZeroPointNotZero,
44 ]:
45 zero_point = testGen.randInt(-128, 128)
46 if zero_point == 0:
47 zero_point = 1
48 return zero_point
49 return 0
50
51 @staticmethod
52 def qgUnary(testGen, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010053 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000054 qinfo = [
55 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
56 TosaQuantGen.getZeroPoint(testGen, dtype),
57 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010058 elif error_name == ErrorIf.OutputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000059 qinfo = [
60 TosaQuantGen.getZeroPoint(testGen, dtype),
61 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
62 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010063 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000064 qinfo = [
65 TosaQuantGen.getZeroPoint(testGen, dtype),
66 TosaQuantGen.getZeroPoint(testGen, dtype),
67 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010068 return qinfo
69
70 @staticmethod
71 def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010072 if isinstance(dtype_or_dtypeList, list):
73 # a list of [input, weights, accumulator] dtypes
74 dtypeList = dtype_or_dtypeList
75 else:
76 # an int, [input, weights, accumulator] dtypes are the same
77 dtypeList = [dtype_or_dtypeList] * 3
78
79 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000080 qinfo = [
81 TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
82 TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
83 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010084 elif error_name == ErrorIf.WeightZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000085 qinfo = [
86 TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
87 TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
88 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010089 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000090 qinfo = [
91 TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
92 TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
93 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010094 return qinfo
95
96 @staticmethod
97 def qgMatmul(testGen, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010098 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000099 qinfo = [
100 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
101 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
102 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100103 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000104 qinfo = [
105 TosaQuantGen.getZeroPoint(testGen, dtype),
106 TosaQuantGen.getZeroPoint(testGen, dtype),
107 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100108 return qinfo
109
110 @staticmethod
111 def computeMultiplierAndShift(scaleFp, scale32):
112 # Derived from computeMultiplierAndShiftTosaScale32
113 # Provide a floating-point scaling factor and the scale32 parameter
114 # to compute the multiplier and shift
115
116 if scale32:
117 scaleBits = 31
118 else:
119 scaleBits = 15
120
121 m, shift = math.frexp(scaleFp)
122
123 if scaleFp < 0.0:
124 m = -m
125
126 multiplier = round(m * (1 << scaleBits))
127 assert multiplier <= (1 << scaleBits)
128
129 if multiplier == (1 << scaleBits):
130 multiplier = multiplier // 2
131 shift = shift + 1
132
133 shift = (-shift) + scaleBits
134 # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
135 # scaleFp, scaleBits, m, multiplier, shift))
136
137 # Adjust multiplier such that shift is in allowed value range.
138 if shift == 0:
139 multiplier = multiplier // 4
140 shift = shift + 2
141 elif shift == 1:
142 multiplier = multiplier // 2
143 shift = shift + 1
144 elif shift == 63:
145 multiplier = multiplier * 2
146 shift = shift - 1
147
148 assert multiplier <= (1 << scaleBits)
149 assert shift >= 2 and shift <= 62
150
151 return multiplier, shift
152
153
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        # All generators are static methods - no instance state required.
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Create one random shape and use a copy of it for every operand.

        opName: operator info dict - "operands" holds (placeholders, consts).
        For ErrorIf.RankMismatch, operands after the first get a shape of a
        different rank.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Create a random rank 4 (NHWC) shape, batch constricted, shared by
        all operands."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Create [values_in, input] rank 3 shapes for scatter/gather ops.

        The two shapes share dimensions 0 and 2; the middle (W) dimension of
        the second shape is chosen at random.
        """
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = min(values_in_shape[0], testGen.args.max_batch_size)

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Create shapes where one randomly chosen operand has a random
        dimension set to 1 so it must broadcast against the others."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        fuzz_idx = testGen.randInt(0, rank)

        for i in range(pl + const):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (OHWI), bias (OC)] shapes for CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Create [ifm (NDHWC), filter (ODHWI), bias (OC)] shapes for CONV3D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension()

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (OHWI), bias (OC)] shapes for
        TRANSPOSE_CONV2D (same layout as CONV2D)."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (HWCM), bias (C*M)] shapes for
        DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        """Create two matching NHW shapes (power-of-two H/W) for FFT2D.

        ERROR_IF variants can force a non power-of-two kernel or mismatched
        input shapes.
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        """Create a single NHW shape (power-of-two H/W) for RFFT2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Create [input (NC), filter (OC x IC), bias (OC)] rank 2 shapes for
        FULLY_CONNECTED."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Create [a (N,H,C), b (N,C,OC)] rank 3 shapes for MATMUL."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Create a variable number of matching shapes for CONCAT; the
        ConcatInputRankMismatch ERROR_IF variant distorts later inputs."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Split the first CONCAT shape along axis so multiple const inputs
        sum back to the original length (ERROR_IF variants distort dims)."""
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
605
class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each tensor in each test."""

    def __init__(self):
        # All generators are static methods - no instance state required.
        pass
611
Jeremy Johnson1271c442023-09-05 11:39:26 +0100612 class TVGInfo:
613 """Enhanced tensor values information including data gen dict."""
614
615 def __init__(self, tensorList, dataGenDict):
616 self.tensorList = tensorList
617 self.dataGenDict = dataGenDict
618
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100619 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000620 def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100621 pCount, cCount = op["operands"]
622
623 tens = []
624 tens.extend(
625 testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
626 )
627 tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
628
629 return tens
630
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100631 # Default high value for random numbers
632 TVG_FLOAT_HIGH_VALUE = {
633 DType.FP32: (1 << 128) - (1 << (127 - 23)),
634 DType.FP16: (1 << 16) - (1 << (15 - 10)),
635 DType.BF16: (1 << 128) - (1 << (127 - 7)),
636 }
637
Jeremy Johnson30476252023-11-20 16:15:30 +0000638 # Default lowest normal values for random numbers
639 TVG_FLOAT_LOW_VALUE = {
640 DType.FP32: np.exp2(-126),
641 DType.FP16: np.exp2(-14),
642 DType.BF16: np.exp2(-126),
643 }
644
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100645 @staticmethod
Jeremy Johnson30476252023-11-20 16:15:30 +0000646 def _get_data_range(testGen, dtype, highValueLookup, lowValueLookup=None):
647 # Return a tuple of (low,high) data range values for the given data
648 # type using a combination of per operator table limits, data limits
649 # and user supplied ranges for FP numbers
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000650 if dtype in highValueLookup:
Jeremy Johnson30476252023-11-20 16:15:30 +0000651 type_range = testGen.getDTypeRange(dtype, high_inclusive=True)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000652 high_val = highValueLookup[dtype]
Jeremy Johnson30476252023-11-20 16:15:30 +0000653 if lowValueLookup is not None and dtype in lowValueLookup:
654 low_val = lowValueLookup[dtype]
655 else:
656 low_val = -high_val
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000657 # Set the values to something that won't produce infinity whilst
Jeremy Johnson30476252023-11-20 16:15:30 +0000658 # respecting the default ranges if more/less than the low/high
659 # values
660 data_range = (
661 max(low_val, type_range[0]),
662 min(high_val, type_range[1]),
663 )
664 if data_range[0] > data_range[1]:
665 # Invalid data range from low to high created due to user
666 # constraints revert to using internal ranges as they are
667 # known to work
668 msg = f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
669 warnings.warn(msg)
670 data_range = (low_val, high_val)
671 return data_range
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000672 return None
673
    @staticmethod
    def tvgLazyGenDefault(
        testGen, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Create tensors plus data generator meta-data for compliance tests.

        For ERROR_IF tests, dtypes without compliance support, or ops with no
        "data_gen" entry, data is generated inline (honouring any range limits
        in argsDict) and no meta-data is returned.  Otherwise a JSON-style
        meta-data dict is built per tensor to drive the external data
        generator - data is produced immediately unless lazy data gen is on.
        NOTE(review): for DOT_PRODUCT ops this may also set argsDict["ksb"]
        as a side effect when the bias tensor (input_pos 2) is non-zero.
        """
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
        tens_ser_list = []

        if (
            error_name is not None
            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
            or "data_gen" not in testGen.TOSA_OP_LIST[opName]
        ):
            # Fall back to internal data gen when dealing with unsupported types or ops
            data_range = argsDict["data_range"] if "data_range" in argsDict else None
            for idx, info in enumerate(zip(shapeList, dtypeList)):
                roundMode = False
                shape, dtype = info
                if "data_range_list" in argsDict:
                    # Per-tensor range (and optional rounding) overrides
                    data_range = argsDict["data_range_list"][idx]["range"]
                    roundMode = (
                        "round" in argsDict["data_range_list"][idx]
                        and argsDict["data_range_list"][idx]["round"] is True
                    )
                # Ignore lazy data gen option and create data array using any range limits
                arr = testGen.getRandTensor(shape, dtype, data_range)
                if roundMode:
                    arr = np.round(arr)
                if idx < pCount:
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
        dg_type = argsDict["dg_type"]
        tens_data = {
            "version": "0.1",
            "tensors": {},
        }
        dg_tens_meta = tens_data["tensors"]
        for idx, shape in enumerate(shapeList):

            tens_meta = {}
            tens_meta["generator"] = gtu.DataGenType(dg_type).name
            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()

            if idx < pCount:
                tens_meta["input_type"] = "VARIABLE"
            else:
                tens_meta["input_type"] = "CONSTANT"

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                # TODO - generate seed for this generator based on test
                info["rng_seed"] = 42

                if "data_range_list" in argsDict:
                    data_range = argsDict["data_range_list"][idx]["range"]
                    if "round" in argsDict["data_range_list"][idx]:
                        info["round"] = argsDict["data_range_list"][idx]["round"]
                elif "data_range" in argsDict:
                    data_range = argsDict["data_range"]
                else:
                    data_range = testGen.getDTypeRange(
                        dtypeList[idx], high_inclusive=True
                    )
                info["range"] = [str(v) for v in data_range]
                tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = int(argsDict["ks"])
                if "acc_type" in argsDict:
                    # Convert type number into JSON name
                    info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
                        "json"
                    ]
                if "kernel" in argsDict:
                    info["kernel"] = [int(k) for k in argsDict["kernel"]]
                if "axis" in argsDict:
                    info["axis"] = int(argsDict["axis"])
                tens_meta["dot_product_info"] = info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"

            # Using the finished generate config meta data - generate the data if
            # needed and assign a tensor name from the serializer

            # Need to generate data when not lazy or for the bias tensor as we need
            # to work out if the bias data is non-zero for compliance
            if not testGen.args.lazy_data_gen or (
                idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
            ):
                # Give this tensor a temporary name until we get one from the serializer
                temp_name = f"placeholder_{idx}"
                dg_tens_meta[temp_name] = tens_meta
                # Create data now using the temporary name to access meta details
                data = testGen.dgl.get_tensor_data(temp_name, tens_data)
                # Remove the item as we will give it the correct name later
                del dg_tens_meta[temp_name]

                if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
                    # The KS value used by compliance verification is altered when the
                    # bias data is non-zero
                    if max(abs(data)) > 0.0:
                        argsDict["ksb"] = argsDict["ks"] + 1

            if testGen.args.lazy_data_gen:
                data = None

            if tens_meta["input_type"] == "VARIABLE":
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
            else:
                tens = testGen.ser.addConst(shape, dtypeList[idx], data)

            tens_ser_list.append(tens)
            # Add the meta data to the list using the serializer tensor name
            dg_tens_meta[tens.name] = tens_meta

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
800
801 @staticmethod
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000802 def tvgNegate(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100803 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000804 # Integer test
805 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100806 pCount, cCount = op["operands"]
807 assert (
808 pCount == 1 and cCount == 0
809 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100810 # Must create tensors with values within accumulator (int32) negatable
811 # range
812 max_val = (1 << 31) - 1
813 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100814 arr = np.int32(
815 testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
816 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000817 tens_ser_list = []
818 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100819 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
820 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000821 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100822 else:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000823 # ERROR_IF or floating point test
824 return TosaTensorValuesGen.tvgLazyGenDefault(
825 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100826 )
827
    # Set the ADD/SUB data range to half the largest value to avoid infinities
    # (the sum/difference of two operands can at most double the magnitude)
    TVG_FLOAT_HIGH_VALUE_ADDSUB = {
        DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
        DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
        DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
    }
834
    @staticmethod
    def tvgAddSub(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Tensor values generator for ADD/SUB.

        INT32 positive tests pre-clip the second operand so the integer result
        cannot saturate; all other cases defer to the lazy/default generator
        with a reduced floating point data range.
        """
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Make sure the integer operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            tens_ser_list = []
            add = op["op"] == Op.ADD
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            # Compute the mathematically exact result in int64 so overflow of
            # int32 can be detected rather than wrapped
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        # amin keeps the most conservative (smallest) correction
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        # amax keeps the most conservative (largest) correction
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            data_range = TosaTensorValuesGen._get_data_range(
                testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
            )
            if data_range:
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
911
912 @staticmethod
913 def tvgCondIfWhileLoop(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000914 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100915 ):
916 if dtypeList[0] in (
917 DType.INT32,
918 DType.INT16,
919 DType.INT8,
920 ):
921 # Limit input tensors with cond_if_binary or while_loop to stop
922 # saturation of add/sub ops with int32 and keep all logical shift
923 # values between 0 to 31 for int16 or int8
924 pCount, cCount = op["operands"]
925 pRemain = pCount
926 placeholders = []
927 for idx, shape in enumerate(shapeList[:]):
928 if dtypeList[0] == DType.INT32:
929 arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
930 else:
931 arr = np.int32(
932 testGen.rng.integers(low=0, high=32, size=shapeList[idx])
933 )
934 if pRemain > 0:
935 placeholders.append(
936 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
937 )
938 pRemain -= 1
939 else:
940 placeholders.append(
941 testGen.ser.addConst(shape, dtypeList[idx], arr)
942 )
943
944 return placeholders
945 else:
946 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000947 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100948 )
949
950 @staticmethod
951 def tvgArithmeticRightShift(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000952 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100953 ):
954 pCount, cCount = op["operands"]
955 # Force value of operand[1] to be within [0, num_bits]
956 assert (
957 pCount == 2 and cCount == 0
958 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
959
960 placeholders = []
961 for idx, shape in enumerate(shapeList[:]):
962 if idx == 1:
963 if dtypeList[idx] == DType.INT8:
964 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
965 elif dtypeList[idx] == DType.INT16:
966 arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
967 elif dtypeList[idx] == DType.INT32:
968 arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
969 elif error_name == ErrorIf.WrongInputType:
970 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
971 else:
972 raise Exception("OpArithmeticRightShift: invalid input dtype")
973 else:
974 arr = testGen.getRandTensor(shape, dtypeList[idx])
975 placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
976
977 return placeholders
978
    @staticmethod
    def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """Tensor values generator for SELECT.

        The first input is the selection condition, so its datatype is forced
        to BOOL before deferring to the default generator.
        """
        # Set datatype of condition tensor to boolean
        dtypeList[0] = DType.BOOL

        return TosaTensorValuesGen.tvgDefault(
            testGen, op, dtypeList, shapeList, testArgs, error_name
        )
987
988 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000989 def tvgIntDiv(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100990 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000991 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100992 pCount, cCount = op["operands"]
993 assert (
994 pCount == 2 and cCount == 0
995 ), "Op.INTDIV must have 2 placeholders, 0 consts"
996
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000997 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100998
999 # Two invalid cases for Op.INTDIV:
1000 # 1. divisor == 0
1001 # 2. dividend == -(1<<31) and divisor == -1
1002 while True:
1003 dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1004 divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
1005
1006 if (divisor_arr == 0).any():
1007 continue
1008
1009 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
1010 continue
1011
1012 break
1013
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001014 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001015 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1016 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001017 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001018 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1019 )
1020
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001021 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001022 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001023 return TosaTensorValuesGen.tvgLazyGenDefault(
1024 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001025 )
1026
    # Set the MUL data range to the square root of the largest value
    # to avoid infinities
    # (the product of two operands bounded by sqrt(max) cannot exceed max)
    TVG_FLOAT_HIGH_VALUE_MUL = {
        DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
1034
    @staticmethod
    def tvgMul(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Tensor values generator for MUL.

        Integer tests repeatedly halve the operands until the (shifted)
        product fits in int32; FP/error tests defer to the default generator
        with a reduced data range.
        """
        if error_name is not None or dtypeList[0] in (
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ):
            # ERROR_IF or floating point test
            data_range = TosaTensorValuesGen._get_data_range(
                testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
            )
            if data_range:
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
        else:
            # Integer test
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            tens_ser_list = []

            # Make sure multiply result in int32 range
            shift = argsDict["shift"]
            if dtypeList[0] == DType.INT8:
                num_bits = 8
            elif dtypeList[0] == DType.INT16:
                num_bits = 16
            elif dtypeList[0] == DType.INT32:
                num_bits = 32
            elif error_name == ErrorIf.WrongInputType:
                num_bits = 8
            else:
                raise Exception("OpMul: invalid input dtype")

            # NOTE(review): idx/shape are unused in this loop body - a_arr and
            # b_arr are rebuilt from shapeList[0]/shapeList[1] each iteration,
            # so only the final iteration's arrays are serialized. The loop
            # looks redundant - confirm before simplifying.
            for idx, shape in enumerate(shapeList[:]):
                low = -(2 ** (num_bits - 1))
                high = (2 ** (num_bits - 1)) - 1

                a_arr = np.int32(
                    testGen.rng.integers(low=low, high=high, size=shapeList[0])
                )
                b_arr = np.int32(
                    testGen.rng.integers(low=low, high=high, size=shapeList[1])
                )

                i = 0
                while True:

                    # Evaluate the scaled product in int64 so overflow of the
                    # int32 result range can be detected
                    a_arr_64 = a_arr.astype(np.int64)
                    b_arr_64 = b_arr.astype(np.int64)

                    if shift > 0:
                        rounding = 1 << (shift - 1)
                        result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                    else:
                        result_arr = a_arr_64 * b_arr_64

                    if (result_arr > -(2**31)).all() and (
                        result_arr <= ((2**31) - 1)
                    ).all():
                        break

                    # Still overflows: halve both operands and retry
                    i = i + 1
                    a_arr = a_arr // 2
                    b_arr = b_arr // 2

            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001115
1116 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001117 def tvgConcat(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001118 count = len(shapeList) - testGen.args.num_const_inputs_concat
1119 if count < 1:
1120 count = 1
1121 if testGen.args.num_const_inputs_concat == 0:
1122 count = len(shapeList)
1123
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001124 shapeList = TosaTensorGen.tgConcatConstInput(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001125 testGen, shapeList, argsDict["axis"], error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001126 )
1127
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001128 tens_ser_list = []
1129 tens_ser_list.extend(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001130 testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
1131 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001132 tens_ser_list.extend(
1133 testGen.buildConstTensors(shapeList[count:], dtypeList[count:])
1134 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001135
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001136 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001137
1138 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001139 def tvgLogicalShift(
1140 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1141 ):
1142 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001143 pCount, cCount = op["operands"]
1144 assert (
1145 pCount == 2 and cCount == 0
1146 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
1147 values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1148 shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001149 tens_ser_list = []
1150 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001151 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1152 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001153 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001154 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1155 )
1156
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001157 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001158
1159 @staticmethod
Jeremy Johnsona0150012023-11-15 15:52:06 +00001160 def tvgEqual(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1161 if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
1162 # Integer
1163 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001164 pCount, cCount = op["operands"]
1165 assert (
1166 pCount == 2 and cCount == 0
1167 ), "Op.EQUAL must have 2 placeholders, 0 consts"
Jeremy Johnsona0150012023-11-15 15:52:06 +00001168
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001169 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1170 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
Jeremy Johnsona0150012023-11-15 15:52:06 +00001171
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001172 # Using random numbers means that it will be very unlikely that
1173 # there are any matching (equal) values, therefore force that
1174 # there are twice the number of matching values as the tensor rank
1175 for num in range(0, len(shapeList[0]) * 2):
1176 a_index = []
1177 b_index = []
1178 # Choose an index in each axis for the whole shape
1179 for axis in range(0, len(shapeList[0])):
1180 # Index can be up to the largest dimension in both shapes
1181 index = np.int32(
1182 testGen.rng.integers(
1183 0, max(shapeList[0][axis], shapeList[1][axis])
1184 )
1185 )
1186 # Reduce the index down to a shape's dim for broadcasting
1187 a_index.append(min(shapeList[0][axis] - 1, index))
1188 b_index.append(min(shapeList[1][axis] - 1, index))
1189
1190 a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
1191
Jeremy Johnsona0150012023-11-15 15:52:06 +00001192 tens_ser_list = []
1193 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001194 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1195 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001196 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001197 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1198 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001199 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001200 else:
Jeremy Johnsona0150012023-11-15 15:52:06 +00001201 # ERROR_IF or floating point test
1202 return TosaTensorValuesGen.tvgLazyGenDefault(
1203 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001204 )
1205
1206 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001207 def tvgReduceSum(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001208 dtype = dtypeList[0]
1209 if dtype == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001210 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001211 pCount, cCount = op["operands"]
1212 assert (
1213 pCount == 1 and cCount == 0
1214 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
1215 # Limit values so that the sum cannot exceed the range of an int32 during
1216 # summation of any axis
1217 range_val = int((1 << 31) / max(shapeList[0]))
1218 values_arr = np.int32(
1219 testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
1220 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001221 tens_ser_list = []
1222 tens_ser_list.append(
Jeremy Johnson30476252023-11-20 16:15:30 +00001223 testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001224 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001225 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001226 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001227 # ERROR_IF or dot product floating point test
Jeremy Johnson30476252023-11-20 16:15:30 +00001228 if (
1229 error_name is None
1230 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
1231 ):
1232 # Limit ranges for (non error & non compliance) tests by using
1233 # values that can be summed on any axis to not hit infinity
1234 highval_lookup = {
1235 dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
1236 / max(shapeList[0])
1237 }
1238 data_range = TosaTensorValuesGen._get_data_range(
1239 testGen, dtype, highval_lookup
1240 )
1241 assert data_range is not None
1242 argsDict["data_range"] = data_range
1243
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001244 return TosaTensorValuesGen.tvgLazyGenDefault(
1245 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001246 )
1247
    # Set the POW exponent high data range
    TVG_FLOAT_HIGH_VALUE_POW_EXP = {
        DType.FP32: 10.0,
        DType.FP16: 10.0,
        DType.BF16: 10.0,
    }
    # POW highest base value (within a safe margin of error) that can be raised
    # to +ve exponent that doesn't become Infinity
    # i.e. base = floor(max_value ** (1 / max_exponent))
    TVG_FLOAT_HIGH_VALUE_POW_BASE = {
        DType.FP32: math.floor(
            math.pow(
                TVG_FLOAT_HIGH_VALUE[DType.FP32],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
            )
        ),
        DType.FP16: math.floor(
            math.pow(
                TVG_FLOAT_HIGH_VALUE[DType.FP16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
            )
        ),
        DType.BF16: math.floor(
            math.pow(
                TVG_FLOAT_HIGH_VALUE[DType.BF16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
            )
        ),
    }
    # POW lowest base value (within a safe margin of error) that can be raised
    # to -ve exponent that doesn't become Infinity
    # (rounded up to 3 decimal places via the * 1000 / 1000 ceil trick)
    TVG_FLOAT_LOW_VALUE_POW_BASE = {
        DType.FP32: math.ceil(
            math.pow(
                1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
            )
            * 1000
        )
        / 1000,
        DType.FP16: math.ceil(
            math.pow(
                1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
            )
            * 1000
        )
        / 1000,
        DType.BF16: math.ceil(
            math.pow(
                1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
            )
            * 1000
        )
        / 1000,
    }
1304
1305 @staticmethod
1306 def tvgPow(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1307 if error_name is not None:
1308 return TosaTensorValuesGen.tvgLazyGenDefault(
1309 testGen, opName, dtypeList, shapeList, argsDict, error_name
1310 )
1311 dtype = dtypeList[0]
1312 # Different ranges for POW
1313 test_set = argsDict["s"]
1314 if test_set == 0:
1315 # Positive base with fractional exponent
1316 base_range = TosaTensorValuesGen._get_data_range(
1317 testGen,
1318 dtype,
1319 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1320 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1321 )
1322 exp_range = TosaTensorValuesGen._get_data_range(
1323 testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
1324 )
1325 exp_round = False
1326 else:
1327 # Integer exponent
1328 exp_range = TosaTensorValuesGen._get_data_range(
1329 testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
1330 )
1331 exp_round = True
1332 if test_set == 1:
1333 # Positive base
1334 base_range = TosaTensorValuesGen._get_data_range(
1335 testGen,
1336 dtype,
1337 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1338 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1339 )
1340 else:
1341 assert test_set == 2
1342 # Negative base
1343 # Supply new look up tables with negative values
1344 base_range = TosaTensorValuesGen._get_data_range(
1345 testGen,
1346 dtype,
1347 {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
1348 {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
1349 )
1350
1351 data_range_list = (
1352 {
1353 "range": base_range,
1354 },
1355 {
1356 "range": exp_range,
1357 "round": exp_round,
1358 },
1359 )
1360 argsDict["data_range_list"] = data_range_list
1361 return TosaTensorValuesGen.tvgLazyGenDefault(
1362 testGen, opName, dtypeList, shapeList, argsDict, error_name
1363 )
1364
1365 @staticmethod
1366 def tvgLogRsqrt(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1367 # LOG & RSQRT data range from lowest expressible positive number to
1368 # largest to avoid NaNs
1369 data_range = TosaTensorValuesGen._get_data_range(
1370 testGen,
1371 dtypeList[0],
1372 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1373 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1374 )
1375 if data_range:
1376 argsDict["data_range"] = data_range
1377
1378 return TosaTensorValuesGen.tvgLazyGenDefault(
1379 testGen, opName, dtypeList, shapeList, argsDict, error_name
1380 )
1381
    # Set the EXP data range to the log of the largest to smallest values
    # to avoid infinities or making the result zero
    # (exp(log(x)) == x, so these bounds map back onto the expressible range)
    TVG_FLOAT_HIGH_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
    TVG_FLOAT_LOW_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
    }
1394
1395 @staticmethod
1396 def tvgExp(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1397 data_range = TosaTensorValuesGen._get_data_range(
1398 testGen,
1399 dtypeList[0],
1400 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1401 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1402 )
1403 if data_range:
1404 argsDict["data_range"] = data_range
1405
1406 return TosaTensorValuesGen.tvgLazyGenDefault(
1407 testGen, opName, dtypeList, shapeList, argsDict, error_name
1408 )
1409
1410 @staticmethod
1411 def tvgFullyConnected(
1412 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1413 ):
1414 dtype = dtypeList[0]
1415 if (
1416 error_name is None
1417 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
Jeremy Johnson718f3472023-11-30 14:18:19 +00001418 and dtype in (DType.BF16,)
Jeremy Johnson30476252023-11-20 16:15:30 +00001419 ):
Jeremy Johnson718f3472023-11-30 14:18:19 +00001420 # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
Jeremy Johnson30476252023-11-20 16:15:30 +00001421 # Limit ranges for (non error & non compliance) FP tests by using
1422 # values that can be multiplied on any axis to not hit infinity/NaN
1423 IC = shapeList[0][1]
1424 highval_lookup = {
1425 dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
1426 }
1427 data_range = TosaTensorValuesGen._get_data_range(
1428 testGen, dtype, highval_lookup
1429 )
1430 assert data_range is not None
1431 argsDict["data_range"] = data_range
1432
1433 return TosaTensorValuesGen.tvgLazyGenDefault(
1434 testGen, opName, dtypeList, shapeList, argsDict, error_name
1435 )
1436
Jeremy Johnson708da822023-11-15 16:25:45 +00001437 @staticmethod
1438 def tvgCast(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1439 in_dtype = dtypeList[0]
1440 out_dtype = argsDict["out_type"]
1441 # Create look up to limit input tensor to output type maximums to avoid
1442 # FP infinities and saturation of integers
1443 out_range = testGen.getDTypeRange(out_dtype, high_inclusive=True)
1444 highval_lookup = {in_dtype: out_range[1]}
1445 data_range = TosaTensorValuesGen._get_data_range(
1446 testGen,
1447 in_dtype,
1448 highval_lookup,
1449 )
1450
1451 assert data_range is not None
1452 argsDict["data_range"] = data_range
1453
1454 return TosaTensorValuesGen.tvgLazyGenDefault(
1455 testGen, opName, dtypeList, shapeList, argsDict, error_name
1456 )
1457
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001458
1459class TosaArgGen:
1460 """Argument generators create exhaustive or random lists of attributes for
1461 operators that take attributes or other parameters.
1462
1463 The return value is a list of (descriptive_name, [arglist]) tuples where
1464 the descriptive_name is appended to the test name and the arglist is expanded
1465 as arguments to the operator build function.
1466 """
1467
    def __init__(self):
        # No instance state - all argument generators are static methods
        pass
1470
1471 @staticmethod
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001472 def _add_data_generators(testGen, opName, dtype, arg_list, error_name):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001473 """Add extra tests for each type of data generator for this op."""
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001474 if (
1475 error_name is None
1476 and "data_gen" in testGen.TOSA_OP_LIST[opName]
1477 and gtu.dtypeIsSupportedByCompliance(dtype)
1478 ):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001479 if dtype in [DType.FP16, DType.FP32, DType.BF16]:
1480 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
1481 else:
1482 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
1483 else:
1484 # Error test or No data generator types listed - assume random
1485 dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
1486
1487 # Expand arg list with other data generator types
1488 new_arg_list = []
1489 for dg_type in dataGenTypesList:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001490 for arg_str, args_dict in arg_list:
1491 args_dict["dg_type"] = dg_type
Jeremy Johnson1271c442023-09-05 11:39:26 +01001492 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
Jeremy Johnson30476252023-11-20 16:15:30 +00001493 if error_name is None:
1494 num_test_sets = (
1495 args_dict["num_test_sets"]
1496 if "num_test_sets" in args_dict
1497 else 0
1498 )
1499 else:
1500 num_test_sets = 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01001501
1502 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
1503 # Extra tests for each dot product test set
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001504 dot_products = args_dict["dot_products"]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001505 if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001506 shape_info = (
1507 " ({})".format(testGen.shapeStr(args_dict["shape"]))
1508 if "shape" in args_dict
1509 else ""
1510 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01001511 print(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001512 f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
Jeremy Johnson1271c442023-09-05 11:39:26 +01001513 )
1514 continue
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001515 # KS and acc_type is required by all dot product generators
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001516 assert "ks" in args_dict
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001517 assert "acc_type" in args_dict
Jeremy Johnson1271c442023-09-05 11:39:26 +01001518
Jeremy Johnson30476252023-11-20 16:15:30 +00001519 num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS
1520
1521 if num_test_sets > 0:
1522 for s in range(0, num_test_sets):
1523 new_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001524 new_args_dict = args_dict.copy()
1525 new_args_dict["s"] = s
1526 new_arg_list.append((new_arg_str, new_args_dict))
Jeremy Johnson30476252023-11-20 16:15:30 +00001527 else:
1528 # Default is a single test
1529 new_arg_list.append((arg_str, args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001530
1531 return new_arg_list
1532
1533 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001534 def agNone(testGen, opName, shapeList, dtype, error_name=None):
1535 """A trivial argument generator for operators that don't take any
1536 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001537 arg_list = TosaArgGen._add_data_generators(
1538 testGen,
1539 opName,
1540 dtype,
1541 [("", {})],
1542 error_name,
1543 )
1544 # Return list of tuples: (arg_str, args_dict)
1545 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001546
1547 @staticmethod
Jeremy Johnson30476252023-11-20 16:15:30 +00001548 def agPow(testGen, opName, shapeList, dtype, error_name=None):
1549 """Pow operator needs different test sets to cover random numbers
1550 without creating NaNs or Infs"""
1551 arg_list = TosaArgGen._add_data_generators(
1552 testGen,
1553 opName,
1554 dtype,
1555 [("", {"num_test_sets": 3})],
1556 error_name,
1557 )
1558 # Return list of tuples: (arg_str, args_dict)
1559 return arg_list
1560
1561 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001562 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
1563 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001564 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001565 shape = shapeList[0]
1566
1567 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001568 # Set too small axis
1569 axes = [testGen.rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001570 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001571 # Set too large axis
1572 axes = [testGen.rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001573 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001574 # Create tests for each dimension
1575 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001576
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001577 opid = testGen.TOSA_OP_LIST[opName]["op"]
1578
1579 for a in axes:
1580 args_dict = {"axis": int(a)}
1581 if opid == Op.REDUCE_SUM:
1582 args_dict["dot_products"] = gtu.product(shape)
1583 args_dict["shape"] = shape
1584 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1585 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1586
1587 arg_list.append(("axis{}".format(a), args_dict))
1588
1589 arg_list = TosaArgGen._add_data_generators(
1590 testGen,
1591 opName,
1592 dtype,
1593 arg_list,
1594 error_name,
1595 )
1596 # Return list of tuples: (arg_str, args_dict)
1597 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001598
1599 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001600 def _calculate_sparsity(num_tests, sparsity_factor):
1601 sparsity = num_tests // sparsity_factor + 1
1602 # If there are only a small number of tests, just select them all
1603 if sparsity < 13:
1604 sparsity = 1
1605 # To get a variety of parameter combinations sparsity should not be a
1606 # multiple of 2, 3 or 5
1607 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1608 sparsity += 1
1609 return sparsity
1610
    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        """Build stride/pad/dilation argument combinations for convolutions.

        Used by CONV2D, CONV3D and DEPTHWISE_CONV2D.  Normally generates a
        sparsely sampled sweep of stride/pad/dilation values that produce an
        integer-exact output shape; in level8k mode only boundary values are
        generated.  Each combination also carries compliance information
        (accumulator type, kernel size KS and dot product count) used by the
        data generators.

        Returns a list of tuples: (arg_str, args_dict).
        """
        # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        # Shape: Batches, (Depth), Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: (OFM channels), (KD), KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Op type checks
        conv3d = opName.startswith("conv3d")
        depthwise = opName.startswith("depthwise")

        # Check the rank
        rank = 5 if conv3d else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits channels
        k_rank = rank - 2
        # depthwise filters are (KH, KW, C, M) so the kernel starts at index 0
        k_pos = 0 if depthwise else 1
        k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
        # compliance size - KS
        k_size = gtu.product(k_shape)
        if not depthwise:
            k_size *= ifm_shape[-1]

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            if error_name == ErrorIf.PadSmallerZero:
                p_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                # Stride must be greater than 1 to force non-integer error
                startStride = (
                    1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
                )
                s_vals = [
                    x for x in range(startStride, testGen.args.max_conv_stride + 1)
                ]
            strides = {x for x in itertools.product(*([s_vals] * k_rank))}
            if error_name == ErrorIf.DilationSmallerOne:
                d_vals = [testGen.rng.choice(range(-5, 1))]
            else:
                d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
            dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[0, bigPadding]] * (k_rank * 2))
                            )
                        }
                    )
                bigStride = 8
                strides.update(
                    {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
                )
                bigDilation = 7
                dilations.update(
                    {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
                )
            max_dim_size = None

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(paddings) * len(strides) * len(dilations), sparsity_factor
            )
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            dilation_shape = [1] * k_rank
            pad_shape = [0] * k_rank * 2
            if conv3d:
                # Small stride apart from for big kernel (see below) to keep
                # tensor size/calculation small
                stride_shape = [1] * k_rank
                for idx in range(k_rank):
                    pad_offset = idx * 2
                    if k_shape[idx] == bigKernel:
                        # Padding shape needs to account for tensor shape
                        pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                        pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                        # Big stride to reduce output size
                        stride_shape[idx] = bigKernel
                    else:
                        # Account for kernel size
                        pad_shape[pad_offset] = k_shape[idx] - 1
            else:
                # Always have a large stride with extra padding and dilation to keep
                # tensor calculation reasonable
                stride_shape = [bigKernel] * k_rank
                for idx in range(k_rank):
                    # Dilation shape must account for kernel size
                    dilation_shape[idx] = bigKernel // k_shape[idx]
                    # Padding shape needs to accommodate tensor/kernel & dilation
                    pad_offset = idx * 2
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

            strides = {tuple(stride_shape)}
            dilations = {tuple(dilation_shape)}
            paddings = {tuple(pad_shape)}
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        # Sweep every stride/pad/dilation combination, keeping one in every
        # `sparsity` of those that yield a valid (or deliberately invalid)
        # output shape
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a positive
                        # sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
                        and (
                            k_rank < 3
                            or (
                                (ifm_shape[3] - 1 + p[4] + p[5])
                                > d[2] * (k_shape[2] - 1)
                            )
                        )
                    ):
                        remainders = []
                        outputs = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            partial = (
                                ifm_shape[index + 1]
                                - 1
                                + p[pad_offset]
                                + p[pad_offset + 1]
                                - (k_shape[index] - 1) * d[index]
                            )
                            remainders.append(partial % s[index])
                            outputs.append((partial // s[index]) + 1)

                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            if (
                                max_dim_size is not None
                                and max(outputs) >= max_dim_size
                            ):
                                # Test will consume too much memory - skip it
                                continue

                            # Compliance - number of dot product calculations
                            if depthwise:
                                # TODO - add support
                                dots = 0
                            else:
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, filter_shape[0])
                                )
                            args_dict = {
                                "acc_type": accum_dtype,
                                "stride": s,
                                "pad": p,
                                "dilation": d,
                                "kernel": k_shape,
                                "ks": k_size,
                                "dot_products": dots,
                                "shape": ifm_shape,
                            }

                            # Support for larger values than 9 needs different delimiter
                            delim = "" if max(s + p + d) <= 9 else "x"
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        delim.join([str(x) for x in s]),
                                        delim.join([str(x) for x in p]),
                                        delim.join([str(x) for x in d]),
                                    ),
                                    args_dict,
                                )
                            )
                    n += 1

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtypes[0],
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
1834
1835 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01001836 def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
1837
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001838 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001839 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01001840
1841 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson1271c442023-09-05 11:39:26 +01001842 accum_dtype = gtu.get_wrong_output_type(opName, testGen.rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01001843 elif error_name == ErrorIf.WrongInputType:
1844 # Pick some potentially correct output dtype if input type is incorrect
1845 accum_dtype = DType.INT32
1846 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +01001847 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01001848
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001849 # Set up compliance info
1850 args_dict = {
1851 "acc_type": accum_dtype,
1852 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
1853 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
1854 "shape": shapeList[0],
1855 }
1856
1857 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
1858
1859 arg_list = TosaArgGen._add_data_generators(
1860 testGen,
1861 opName,
1862 input_dtype,
1863 arg_list,
1864 error_name,
1865 )
1866 # Return list of tuples: (arg_str, args_dict)
1867 return arg_list
James Ward8b390432022-08-12 20:48:56 +01001868
1869 @staticmethod
1870 def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
1871 # Get valid accumulate type(s)
1872 if dtype == DType.INT8:
1873 accum_dtypes = [DType.INT32]
1874 elif dtype == DType.INT16:
1875 accum_dtypes = [DType.INT48]
1876 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001877 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01001878 elif dtype == DType.BF16:
1879 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001880 elif dtype == DType.FP32:
1881 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01001882 elif error_name is None:
1883 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
1884
1885 if error_name == ErrorIf.WrongOutputType:
1886 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson1271c442023-09-05 11:39:26 +01001887 accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01001888 elif error_name == ErrorIf.WrongInputType:
1889 # Pick some potentially correct output dtype if input type is incorrect
1890 accum_dtypes = [DType.INT32]
1891
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001892 # Set up compliance info
1893 args_dict = {
1894 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
1895 # Set dot_products = N*H*W
1896 "dot_products": gtu.product(
1897 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
1898 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001899 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001900 }
1901
1902 # Create arg tuple of string and dict
1903 arg_list = []
1904 for a in accum_dtypes:
1905 d = args_dict.copy()
1906 d["acc_type"] = a
1907 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001908
1909 arg_list = TosaArgGen._add_data_generators(
1910 testGen,
1911 opName,
1912 dtype,
1913 arg_list,
1914 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01001915 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001916 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01001917 return arg_list
James Ward8b390432022-08-12 20:48:56 +01001918
1919 @staticmethod
1920 def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001921 arg_list = []
1922
Jeremy Johnson0c716862023-04-13 17:18:19 +01001923 if testGen.args.level8k and error_name is not None:
1924 # Don't produce negative large tests
1925 return arg_list
1926
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001927 ifm_shape = shapeList[0]
1928 filter_shape = shapeList[1]
1929
Jeremy Johnson1271c442023-09-05 11:39:26 +01001930 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01001931
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001932 # Must be rank 4
1933 if error_name != ErrorIf.WrongRank:
1934 assert len(ifm_shape) == 4
1935 assert len(filter_shape) == 4
1936
Jeremy Johnson0c716862023-04-13 17:18:19 +01001937 k_shape = tuple(filter_shape[1:3])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001938
Jeremy Johnson0c716862023-04-13 17:18:19 +01001939 if not testGen.args.level8k:
1940 # Generate comprehensive argument lists
1941 # - except for named errors, which use specific invalid value(s)
1942 smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
1943 if error_name == ErrorIf.PadLargerEqualKernel:
1944 max_filter_size = -max(k_shape[0], k_shape[1])
1945 p_vals = [
1946 testGen.rng.choice(range(max_filter_size - 10, max_filter_size))
1947 ]
1948 else:
1949 p_vals = [
1950 x
1951 for x in range(
1952 smallest_padding_size, testGen.args.max_conv_padding + 1
1953 )
1954 ]
1955 paddings = {x for x in itertools.product(*([p_vals] * 4))}
1956 if error_name == ErrorIf.StrideSmallerOne:
1957 # Can't use stride=0, as it is used to derive output shape, as a divisor
1958 s_vals = [testGen.rng.choice(range(-5, 0))]
1959 else:
1960 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
1961 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001962
Jeremy Johnson0c716862023-04-13 17:18:19 +01001963 if not error_name and testGen.args.oversize:
1964 # add some oversize argument values
1965 if max(ifm_shape) < 64:
1966 bigPadding = 9
1967 paddings.update(
1968 {
1969 x
1970 for x in itertools.product(
1971 *([[smallest_padding_size, bigPadding]] * 4)
1972 )
1973 }
1974 )
1975 bigStride = 8
1976 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
1977
1978 # There are too many parameter combinations, so generate them sparsely,
1979 # very sparse for negative tests
1980 sparsity_factor = 2 if error_name else 10
1981 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
1982 # If there are only a small number of tests, just select them all
1983 if sparsity < 13:
1984 sparsity = 1
1985 # To get a variety of parameter combinations sparsity should not be a
1986 # multiple of 2, 3 or 5
1987 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1988 sparsity += 1
1989 else:
1990 # Only test 8k levels boundaries
1991 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
1992 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
1993 bigPadding = bigKernel
1994
1995 pad_shape = [0] * (len(k_shape) * 2)
1996 stride_shape = [1] * len(k_shape)
1997 # The point at which input dimension combined with the stride will
1998 # create large output sizes!
1999 LARGE_SIZE = 2
2000 for idx in range(len(k_shape)):
2001 pad_offset = idx * 2
2002 if k_shape[idx] == bigKernel:
2003 # Set large stride
2004 stride_shape[idx] = bigKernel
2005 # Use negative output padding to reduce shape size
2006 pad_shape[pad_offset] = -(bigPadding - 1)
2007 if ifm_shape[idx + 1] > LARGE_SIZE:
2008 pad_shape[pad_offset + 1] = -(bigPadding - 1)
2009 else:
2010 # The other dimension should be the bigKernel
2011 alt_idx = 1 - idx
2012 if (
2013 k_shape[alt_idx] == bigKernel
2014 and ifm_shape[alt_idx + 1] < LARGE_SIZE
2015 ):
2016 # As the input is small, the large stride won't
2017 # affect the output so we can add some padding
2018 pad_shape[pad_offset + 1] = bigPadding
2019
2020 strides = {tuple(stride_shape)}
2021 paddings = {tuple(pad_shape)}
2022
2023 # Currently allow all combinations that are reasonable size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002024 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002025
2026 n = 0
2027 for s in sorted(list(strides)):
2028 for p in sorted(list(paddings)):
TatWai Chong24594f52022-06-08 00:48:04 -07002029 if n % sparsity == 0:
2030 # Determine the output shape
Jeremy Johnson0c716862023-04-13 17:18:19 +01002031 oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
2032 ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
TatWai Chong24594f52022-06-08 00:48:04 -07002033 os = [ifm_shape[0], oh, ow, filter_shape[0]]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002034
2035 # Support for larger values than 9 needs different delimiter
2036 delim = "" if max(s + p) <= 9 else "x"
TatWai Chong24594f52022-06-08 00:48:04 -07002037 arg_list.append(
2038 (
James Ward8b390432022-08-12 20:48:56 +01002039 "acc{}_st{}_pad{}_os{}".format(
2040 testGen.typeStr(accum_dtype),
Jeremy Johnson0c716862023-04-13 17:18:19 +01002041 delim.join([str(x) for x in s]),
2042 delim.join([str(x) for x in p]),
TatWai Chong24594f52022-06-08 00:48:04 -07002043 "x".join([str(x) for x in os]),
2044 ),
James Ward8b390432022-08-12 20:48:56 +01002045 [accum_dtype, s, p, os],
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002046 )
TatWai Chong24594f52022-06-08 00:48:04 -07002047 )
2048 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002049
2050 return arg_list
2051
2052 @staticmethod
2053 def agPad(testGen, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002054 rank = len(shapeList[0])
2055
2056 # Exhaustively test combinations of padding on each side of each dimension
2057 # - the range of padding values is defined by pad_min and pad_max
2058 # - for padding >9, the name format needs to be more distinctive
2059 pad_min, pad_max = 0, 1
2060 pad_values = [x for x in range(pad_min, pad_max + 1)]
2061 if error_name == ErrorIf.PadSmallerZero:
2062 pad_values = [x for x in range(-2, 0)]
2063 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
2064 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
2065
2066 if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
2067 pad_const_int = testGen.getRandNumberDType(dtype)
2068 pad_const_fp = 0
James Wardf0890992022-11-17 11:15:14 +00002069 elif dtype in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002070 pad_const_int = 0
2071 pad_const_fp = testGen.getRandNumberDType(dtype)
2072 else:
2073 return []
2074
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002075 list_shape_pad_values = list(shape_pad_values)
2076 # If we are producing tests for rank 6 or greater use sparsity
2077 if len(list_shape_pad_values) > 1024:
2078 sparsity_factor = 2 if error_name else 120
2079 sparsity = TosaArgGen._calculate_sparsity(
2080 len(list_shape_pad_values), sparsity_factor
2081 )
2082 else:
2083 sparsity = 1
2084
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002085 # Build arg list
2086 arg_list = []
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002087 for n, paddings in enumerate(list_shape_pad_values):
James Ward8b390432022-08-12 20:48:56 +01002088 paddings = list(paddings)
2089 args_valid = True
2090
2091 if error_name == ErrorIf.PadSmallerZero:
2092 # Prevent negative output shapes while ensuring still testing for negative padding
2093 for i in range(rank):
2094 dim_after_padding = (
2095 paddings[i][0] + paddings[i][1] + shapeList[0][i]
2096 )
2097 if dim_after_padding < 1:
2098 paddings[i] = (0, 0)
2099 if all([p > -1 for p in paddings[i]]):
2100 args_valid = False
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002101 if args_valid and n % sparsity == 0:
James Ward8b390432022-08-12 20:48:56 +01002102 name = "pad"
2103 for r in range(rank):
2104 before, after = paddings[r]
2105 name = f"{name}{before}{after}"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002106 args_dict = {
2107 "pad": np.array(paddings),
2108 "pad_const_int": pad_const_int,
2109 "pad_const_fp": pad_const_fp,
2110 }
2111 arg_list.append((name, args_dict))
James Ward8b390432022-08-12 20:48:56 +01002112
2113 if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
2114 warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002115
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002116 arg_list = TosaArgGen._add_data_generators(
2117 testGen,
2118 opName,
2119 dtype,
2120 arg_list,
2121 error_name,
2122 )
2123
2124 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002125 return arg_list
2126
2127 @staticmethod
2128 def agPooling(testGen, opName, shapeList, dtype, error_name=None):
2129 arg_list = []
2130
2131 shape = shapeList[0]
2132 if error_name != ErrorIf.WrongRank:
2133 assert len(shape) == 4
2134
Jeremy Johnson0c716862023-04-13 17:18:19 +01002135 test_level8k = testGen.args.level8k and error_name is None
2136
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002137 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002138 startKernel = 2
2139 startPad = 0
2140 if not test_level8k:
2141 # Generate comprehensive argument lists
2142 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2143 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2144 # Stride must be greater than 1 to force non-integer error
2145 s_vals = [
2146 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2147 ]
2148 strides = {x for x in itertools.product(*([s_vals] * 2))}
2149 k_vals = [
2150 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2151 ]
2152 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2153 max_dim_size = None
2154 else:
2155 # Only test 8k levels
2156 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2157 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2158 strides = {(1, bigStride), (bigStride, 4)}
2159 kernels = {(1, bigKernel), (bigKernel, 3)}
2160 paddings = set()
2161 for s in sorted(list(strides)):
2162 for k in sorted(list(kernels)):
2163 padding = []
2164 for idx in range(len(k)):
2165 total_padding = s[idx] - shape[idx + 1] + k[idx]
2166 while total_padding < 0:
2167 # Must meet: shape + padding > kernel
2168 total_padding += s[idx]
2169 if total_padding < k[idx]:
2170 padding.extend([0, total_padding])
2171 else:
2172 # Note this may produce padding >= k[idx] which is not
2173 # allowed - but will be ignored in the creation loop below
2174 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2175 paddings.add(tuple(padding))
2176 # Create a limit for the output dimensions size
2177 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002178
James Ward8b390432022-08-12 20:48:56 +01002179 if opName == "max_pool2d":
2180 accum_dtypes = [None] # max_pool has no accumulate dtype
2181 elif dtype == DType.INT8 or dtype == DType.INT16:
2182 accum_dtypes = [DType.INT32]
2183 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002184 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002185 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002186 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01002187 elif error_name is None:
2188 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2189 else:
2190 # Set to something for the ErrorIf case which has
2191 # incorrect input data-type
2192 accum_dtypes = [DType.INT32]
2193
Jeremy Johnson0c716862023-04-13 17:18:19 +01002194 if not test_level8k:
2195 if testGen.args.oversize:
2196 # add some oversize argument values
2197 bigStride = 7
2198 bigKernel = 9
2199 strides.update(
2200 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002201 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002202 kernels.update(
2203 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2204 )
2205 if max(shape) < 64:
2206 # padding must be less than the kernel size
2207 bigPadding = bigKernel - 1
2208 paddings.update(
2209 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2210 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002211
Jeremy Johnson0c716862023-04-13 17:18:19 +01002212 # There are too many parameter combinations, so generate them sparsely,
2213 # very sparse for negative tests
2214 sparsity_factor = 2 if error_name else 500
2215 sparsity = (
2216 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2217 )
2218 else:
2219 # We have already limited test output combinations for 8k tests
2220 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002221
James Ward8b390432022-08-12 20:48:56 +01002222 arg_str = (
2223 "acc{}_st{}_kern{}_pad{}"
2224 if accum_dtypes[0] is not None
2225 else "st{}_kern{}_pad{}"
2226 )
2227
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002228 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01002229 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002230 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01002231
2232 # Support for larger values than 9 needs different delimiter
2233 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01002234 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01002235 delim.join([str(x) for x in stride]),
2236 delim.join([str(x) for x in kern]),
2237 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01002238 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002239 args_dict = {
2240 "stride": stride,
2241 "pad": pad,
2242 "kernel": kern,
2243 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002244 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002245 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
2246 }
James Ward8b390432022-08-12 20:48:56 +01002247
2248 if accum is not None:
2249 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002250 args_dict["acc_type"] = accum
2251 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002252
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002253 n = 0
James Ward8b390432022-08-12 20:48:56 +01002254 for a in accum_dtypes:
2255 for s in sorted(list(strides)):
2256 for p in sorted(list(paddings)):
2257 for k in sorted(list(kernels)):
2258 if error_name in [
2259 ErrorIf.StrideSmallerOne,
2260 ErrorIf.KernelSmallerOne,
2261 ErrorIf.PadSmallerZero,
2262 ErrorIf.PadLargerEqualKernel,
2263 ]:
2264 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
2265 testGen, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002266 )
James Ward8b390432022-08-12 20:48:56 +01002267 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002268 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002269 get_arg_list_element(a, sNew, pNew, kNew, shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002270 )
James Ward8b390432022-08-12 20:48:56 +01002271 elif (
2272 n % sparsity == 0
2273 # padding must not exceed the kernel size
2274 and p[0] < k[0]
2275 and p[1] < k[0]
2276 and p[2] < k[1]
2277 and p[3] < k[1]
2278 # the padded shape must exceed the kernel size
2279 and (shape[1] + p[0] + p[1]) > k[0]
2280 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002281 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002282 partial_h = shape[1] + p[0] + p[1] - k[0]
2283 partial_w = shape[2] + p[2] + p[3] - k[1]
2284 remainder_h = partial_h % s[0]
2285 remainder_w = partial_w % s[1]
2286 output_h = partial_h // s[0] + 1
2287 output_w = partial_w // s[1] + 1
2288 # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
James Ward8b390432022-08-12 20:48:56 +01002289 if (
2290 # the parameters must produce integer exact output
2291 error_name != ErrorIf.PoolingOutputShapeNonInteger
2292 and remainder_h == 0
2293 and remainder_w == 0
2294 ) or (
2295 error_name == ErrorIf.PoolingOutputShapeNonInteger
2296 and (remainder_h != 0 or remainder_w != 0)
2297 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002298 if (
2299 max_dim_size is not None
2300 and max(output_h, output_w) > max_dim_size
2301 ):
2302 # Test will consume too much memory - skip it
2303 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002304 # Dot products = N*OH*OW*C
2305 dp = gtu.product(
2306 (shape[0], output_h, output_w, shape[3])
2307 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002308 arg_list.append(
2309 get_arg_list_element(a, s, p, k, dp, shape)
2310 )
James Ward8b390432022-08-12 20:48:56 +01002311 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002312
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002313 # Now add data generator types
2314 arg_list = TosaArgGen._add_data_generators(
2315 testGen,
2316 opName,
2317 dtype,
2318 arg_list,
2319 error_name,
2320 )
2321
2322 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002323 return arg_list
2324
2325 @staticmethod
2326 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
2327 arg_list = []
2328
2329 # Enumerate the output types here
2330 if error_name == ErrorIf.WrongOutputType:
2331 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
2332 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00002333 dtypeList = [
2334 DType.BOOL,
2335 DType.INT16,
2336 DType.INT32,
2337 DType.FP16,
2338 DType.BF16,
2339 DType.FP32,
2340 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002341 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00002342 dtypeList = [
2343 DType.BOOL,
2344 DType.INT8,
2345 DType.INT32,
2346 DType.FP16,
2347 DType.BF16,
2348 DType.FP32,
2349 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002350 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00002351 dtypeList = [
2352 DType.BOOL,
2353 DType.INT8,
2354 DType.INT16,
2355 DType.FP16,
2356 DType.BF16,
2357 DType.FP32,
2358 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002359 elif inDtype == DType.BOOL:
2360 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01002361 elif inDtype == DType.FP16:
James Ward736fd1a2023-01-23 17:13:37 +00002362 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002363 elif inDtype == DType.BF16:
James Ward736fd1a2023-01-23 17:13:37 +00002364 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002365 elif inDtype == DType.FP32:
James Ward736fd1a2023-01-23 17:13:37 +00002366 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002367 elif error_name == ErrorIf.WrongInputType:
2368 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002369 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002370 else:
2371 raise Exception("Unexpected input dtype: {}".format(inDtype))
2372
2373 for dtype in dtypeList:
Jeremy Johnson708da822023-11-15 16:25:45 +00002374 arg_list.append(
2375 ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
2376 )
2377
2378 # Now add data generator types
2379 arg_list = TosaArgGen._add_data_generators(
2380 testGen,
2381 opName,
2382 dtype,
2383 arg_list,
2384 error_name,
2385 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002386
2387 return arg_list
2388
    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        """Generate argument combinations for the RESCALE operator.

        Enumerates every (output dtype, scale32, double_round, per_channel)
        combination and filters out those that are illegal per the TOSA
        specification, or that do not suit the requested ERROR_IF case.
        The order of the 'continue' filters below is deliberate: each ERROR_IF
        check must run before the corresponding legality check so that
        negative tests are still produced.

        Args:
            testGen: test generator providing typeStr() and helpers
            opName: operator name the arguments are generated for
            shapeList: list of input shapes (unused here)
            inDtype: the input DType being rescaled from
            error_name: optional ErrorIf case to generate invalid args for

        Returns:
            List of tuples:
            (arg_str, [outDtype, scale32, double_round, per_channel]).
        """
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition. Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [outDtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list
2487
2488 @staticmethod
2489 def agMul(testGen, opName, shapeList, dtype, error_name=None):
2490 arg_list = []
2491
2492 if dtype is DType.INT32:
2493 for p in range(testGen.args.num_rand_permutations):
2494
2495 shift = testGen.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002496 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002497 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002498 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002499
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002500 arg_list = TosaArgGen._add_data_generators(
2501 testGen,
2502 opName,
2503 dtype,
2504 arg_list,
2505 error_name,
2506 )
2507 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002508 return arg_list
2509
2510 @staticmethod
2511 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
2512 arg_list = []
2513
2514 arg_list.append(("roundTrue", [True]))
2515 arg_list.append(("roundFalse", [False]))
2516
2517 return arg_list
2518
Luke Hutton57287132023-02-06 14:54:18 +00002519 @staticmethod
2520 def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
2521 arg_list = []
2522
2523 arg_list.append(("inverseTrue", [True]))
2524 arg_list.append(("inverseFalse", [False]))
2525
2526 return arg_list
2527
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002528 # Helper function for reshape. Gets some factors of a larger number.
2529 @staticmethod
2530 def getFactors(val, start=1):
2531 factors = []
2532
2533 for i in range(start, int(np.sqrt(val)) + 1):
2534 if (val % i) == 0:
2535 factors.append(i)
2536
2537 return factors
2538
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        """Generate 'new_shape' arguments for the RESHAPE operator.

        For each permutation a random rank is chosen and the input element
        count is factorised into that many dimensions.  Two variants are
        produced per generated shape: one fully defined and one with a single
        inferred (-1) dimension.  ERROR_IF cases deliberately corrupt the
        inferred variant instead.

        Args:
            testGen: test generator providing rng/randInt and limits
            opName: operator name the arguments are generated for
            shapeList: list of input shapes; shapeList[0] is reshaped
            dtype: tensor DType (passed through to data generators)
            error_name: optional ErrorIf case to generate invalid args for

        Returns:
            List of tuples: (arg_str, {"new_shape": shape}).
        """
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            # Rank from 1 to TOSA_TENSOR_MAX_RANK
            newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                new_shape_inferred = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                # Dimension (1-based here) that will be replaced with -1
                inferred_dim = testGen.rng.integers(1, newRank + 1)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    if i == inferred_dim:
                        new_shape_inferred.append(-1)
                    else:
                        new_shape_inferred.append(shuffledFactors[0])
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # Last dimension takes whatever element count is left over
                newShape.append(remainingElements)
                if inferred_dim == newRank:
                    new_shape_inferred.append(-1)
                else:
                    new_shape_inferred.append(remainingElements)

                # Check for duplicates
                found = False
                for name, args_dict in arg_list:
                    if args_dict["new_shape"] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                if error_name in [
                    ErrorIf.ReshapeOutputSizeNonInteger,
                    ErrorIf.ReshapeOutputSizeMultiInference,
                ]:
                    if newRank < 2:
                        # Need at least two dimensions
                        continue
                    # NOTE: Change inferred_dim starting offset from 1 to 0
                    inferred_dim -= 1
                    # Pick a different dimension to corrupt for the ERROR_IF
                    extra_dim = inferred_dim + testGen.rng.integers(1, newRank)
                    extra_dim = extra_dim % newRank
                    assert extra_dim != inferred_dim
                    if error_name == ErrorIf.ReshapeOutputSizeNonInteger:
                        # Grow extra_dim until the element count no longer
                        # divides evenly, making the inferred size non-integer
                        elements = 1
                        for i, dim_value in enumerate(new_shape_inferred):
                            if i != inferred_dim and i != extra_dim:
                                elements *= dim_value
                        dim_value = new_shape_inferred[extra_dim]
                        while totalElements % (elements * dim_value) == 0:
                            dim_value += 1
                        new_shape_inferred[extra_dim] = dim_value
                    else:
                        assert error_name == ErrorIf.ReshapeOutputSizeMultiInference
                        # Two -1 dimensions - illegal multi-inference
                        new_shape_inferred[extra_dim] = -1
                else:
                    arg_list.append(
                        (
                            "perm{}_rank{}_outdefined".format(p, newRank),
                            {"new_shape": newShape},
                        )
                    )
                if error_name != ErrorIf.TensorSizeInputOutputMismatch:
                    arg_list.append(
                        (
                            "perm{}_rank{}_outinferred".format(p, newRank),
                            {"new_shape": new_shape_inferred},
                        )
                    )

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        return arg_list
2646
2647 @staticmethod
2648 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
2649 arg_list = []
2650
2651 ifm_shape = shapeList[0]
2652
2653 if error_name == ErrorIf.IndexOutsideBounds:
2654 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
2655 incorrect_small_index = range(-len(ifm_shape), 0)
2656 permutations = [p for p in itertools.permutations(incorrect_large_index)]
2657 permutations.extend(
2658 [p for p in itertools.permutations(incorrect_small_index)]
2659 )
2660 elif error_name == ErrorIf.IndexUsedTwice:
2661 # Create list with a duplicated index
2662 perm_range = list(range(len(ifm_shape)))
2663 index_choice = testGen.rng.choice(range(len(perm_range)))
2664 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
2665 permutations = [p for p in itertools.permutations(perm_range)]
2666
2667 else:
2668 # Get all permutations
2669 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
2670
2671 # Limit to possible permutations from shape dimension or argument setting
2672 limit = min(len(permutations), testGen.args.num_rand_permutations)
2673
2674 # Get random permutation generator that uses all permutations
2675 random_permutations = testGen.rng.permutation(permutations)
2676
2677 # Create list of required amount of permutations
2678 arg_list = [
2679 ("perm{}".format(p), [random_permutations[p].tolist()])
2680 for p in range(limit)
2681 ]
2682 return arg_list
2683
2684 @staticmethod
2685 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
2686 arg_list = []
2687
2688 ifm_shape = shapeList[0]
2689 rank = len(ifm_shape)
2690
2691 for p in range(testGen.args.num_rand_permutations):
2692 start = []
2693 size = []
2694
2695 valid = True
2696
2697 for i in range(rank):
2698 if ifm_shape[i] > 1:
2699 start.append(testGen.randInt(0, ifm_shape[i]))
2700 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
2701
2702 # Invalid slice size?
2703 if size[i] == 0:
2704 valid = False
2705 else:
2706 start.append(0)
2707 size.append(1)
2708
2709 if valid:
2710 # If ERROR_IF test required then incorrect start, size will be returned
2711 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
2712 testGen, error_name, ifm_shape, start, size
2713 )
2714 arg_list.append(("perm{}".format(p), [start, size]))
2715 return arg_list
2716
2717 @staticmethod
2718 def agTile(testGen, opName, shapeList, dtype, error_name=None):
2719 arg_list = []
2720
2721 ifm_shape = shapeList[0]
2722 rank = len(ifm_shape)
2723
2724 for p in range(testGen.args.num_rand_permutations):
2725
2726 # Pick a few random, but small multiple values
2727 # because otherwise this has a tendency to generate
2728 # enormous tensors
2729 multiples = []
2730 for i in range(rank):
2731 if ifm_shape[i] > 1000:
2732 # Multiple of 1 if ifm_shape dimension is large to reduce
2733 # tensor size
2734 multiples.append(1)
2735 elif max(ifm_shape) > 1000:
2736 multiples.append(2)
2737 else:
2738 multiples.append(testGen.randInt(1, 4))
2739 arg_list.append(("perm{}".format(p), [multiples]))
2740
2741 return arg_list
2742
    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (mode, scale, offset, border, dtypes) arguments for RESIZE.

        Four nested helpers each produce one style of random parameters
        (aspect-ratio, up/down-scale, fully random, or level-8k).  The main
        loop then filters each candidate against output-size, max-scale and
        ERROR_IF constraints until num_rand_permutations valid sets exist
        per mode/output-dtype combination.

        Args:
            testGen: test generator providing rng/randInt, args and limits
            opName: operator name the arguments are generated for
            shapeList: list of input shapes; shapeList[0] is NHWC input
            dtype: input tensor DType
            error_name: optional ErrorIf case to generate invalid args for

        Returns:
            List of tuples:
            (arg_str, [mode, scale, offset, border, dtype, outputDType]).
        """
        arg_list = []
        ifm_shape = shapeList[0]

        def get_aspect_ratio_resize_params():
            # Pick a common aspect ratio with letterbox or pillarbox borders
            common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
            aspect_ratio = testGen.rng.choice(common_aspect_ratios)
            invert = testGen.rng.choice((False, True))
            letterbox = testGen.rng.choice((False, True))

            scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
            scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
            scale_y_d = scale_x_d = 1
            offset_x = offset_y = 0

            if letterbox:
                max_border = scale_y_n
                border_y = testGen.randInt(low=0, high=max_border)
                border_x = 0
            else:
                # Pillarboxing
                border_y = 0
                max_border = scale_x_n
                border_x = testGen.randInt(low=0, high=max_border)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)

            return scale, offset, border

        def get_upscale_downscale_params():
            # Loop until a consistent parameter set is found
            valid_params = False
            while not valid_params:
                upscale = testGen.rng.choice((False, True))

                # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
                origin_sampling = testGen.rng.choice((False, True))

                if upscale:
                    shift = testGen.randInt(low=1, high=4)
                    scale_x_d = scale_y_d = 1
                    scale_x_n = scale_y_n = (
                        1 << shift if origin_sampling else 2 << shift
                    )
                    border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                    offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
                else:
                    scale_x_n = 1
                    scale_y_n = 1

                    # Return list of valid scale_*_d values (max value 4) given input dim shape
                    # NOTE(review): selects denominators where ifm_dim % x == 1
                    # (not == 0) — presumably to keep sampling exact; confirm.
                    def get_valid_denom(ifm_dim):
                        return [x for x in range(1, 5) if ifm_dim % x == 1]

                    # Generate list of valid downscale values and choose one randomly
                    valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                    valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                    if not valid_scale_y_ds and not valid_scale_x_ds:
                        # Bad parameters, skip
                        continue

                    if not valid_scale_y_ds:
                        scale_y_d = 1
                    else:
                        scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                    if not valid_scale_x_ds:
                        scale_x_d = 1
                    else:
                        scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                    border_x = border_y = 0
                    offset_y = testGen.randInt(0, 16 * scale_y_n)
                    offset_x = testGen.randInt(0, 16 * scale_x_n)
                valid_params = True

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_rand_params():
            # Increase the denominator until n/d is within the level max scale
            def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
                scale = scale_n / scale_d
                if scale > max_scale:
                    factor = scale / max_scale
                    new_scale_d = math.ceil(scale_d * factor)
                    assert scale_n / new_scale_d <= max_scale
                    scale_d = new_scale_d
                return scale_d

            # Scale
            scale_y_n = testGen.randInt(low=1, high=(1 << 11))
            scale_x_n = testGen.randInt(low=1, high=(1 << 11))

            scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

            scale_y_d = fix_scale_to_max_scale(
                scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_x_d = fix_scale_to_max_scale(
                scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )

            # Offsets and border within the scale
            offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
            border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_level_8k_params():
            # Create 64x scale - 64/1 to 2048/32
            # NOTE(review): 'high' is a float here (true division) — randInt
            # presumably accepts non-integer bounds; confirm.
            scale_d = testGen.randInt(
                low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
            # Create half to fifth scaling
            scale_d_alt = testGen.randInt(low=2, high=6)
            scale_n_alt = 1
            switch = testGen.rng.choice((False, True))
            if switch:
                scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
            else:
                scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)

            offset_y = testGen.rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
            offset_x = testGen.rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
            offset = (offset_y, offset_x)
            border_y = testGen.rng.choice((-16 * scale[0], 0, scale[0] - 1))
            border_x = testGen.rng.choice((-16 * scale[2], 0, scale[2] - 1))
            border = (border_y, border_x)
            return scale, offset, border

        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.BF16:
                outputDTypeList = [DType.BF16]
            elif dtype == DType.FP32:
                outputDTypeList = [DType.FP32]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                perm = 0
                while perm < testGen.args.num_rand_permutations:
                    # Random choice of type of params we are testing
                    if not testGen.args.level8k:
                        _rnd_param_fn = testGen.rng.choice(
                            (
                                get_rand_params,
                                get_upscale_downscale_params,
                                get_aspect_ratio_resize_params,
                            )
                        )
                        scale, offset, border = _rnd_param_fn()
                    else:
                        scale, offset, border = get_level_8k_params()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        # Look for non-integer test
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            if perm > 0:
                                perm += 1
                            continue
                    else:
                        # Alter the scaling factors to make the output integer
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1
                        # Make sure we are still within max scaling
                        if (
                            scale_y_n / scale_y_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
                            scale_x_n / scale_x_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
                            # Skip the test as it is using too large a scaling factor
                            if perm > 0:
                                perm += 1
                            continue

                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1

                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        if not testGen.args.level8k or perm > 0:
                            perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= gtu.MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= gtu.MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]

                    # Common for all data types
                    if error_name is not None:
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        [
                            mode,
                            scale,
                            offset,
                            border,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)
        return arg_list
3062
3063 @staticmethod
3064 def agTable(testGen, opName, shapeList, dtype, error_name=None):
3065 arg_list = []
3066
3067 if dtype == DType.INT8:
3068 table = np.int32(
3069 testGen.rng.integers(low=-128, high=128, size=[256])
3070 ).tolist()
3071 else: # INT16
3072 table = np.int32(
3073 testGen.rng.integers(low=-32768, high=32768, size=[513])
3074 ).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003075 # Make sure all slopes are within REQUIRE min/max 16-bit int
3076 for idx in range(len(table) - 1):
3077 slope = table[idx + 1] - table[idx]
3078 # Alter the next table entry to force the slope to be ok
3079 if slope > 32767:
3080 table[idx + 1] -= slope - 32767
3081 if slope < -32768:
3082 table[idx + 1] -= slope + 32768
3083 slope = table[idx + 1] - table[idx]
3084 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003085 arg_list.append(
3086 (
3087 "",
3088 [table],
3089 )
3090 )
3091 return arg_list
3092
3093 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
3094 # CondIf generates the condition values here.
3095 # Convert to tensors in the build function, along with the
3096 # then and else blocks
3097 arg_list = []
3098
3099 for c in [False, True]:
3100 arg_list.append(("cond{}".format(int(c)), [c]))
3101
3102 return arg_list
3103
3104 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
3105 # While loop: 0 iterations, 1, more than 1
3106 arg_list = []
3107
3108 for iter in [0, 1, 4]:
3109 arg_list.append(("iter{}".format(iter), [iter]))
3110
3111 return arg_list