blob: 0851acad16ec118b3ecc30bd051d42effe36ac3d [file] [log] [blame]
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
4import math
James Ward8b390432022-08-12 20:48:56 +01005import warnings
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01006
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
16# DTypeNames, DType, Op and ResizeMode are convenience variables to the
17# flatc-generated types that should be enums, but aren't
18
19
class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        # Stateless - all helpers are static methods
        pass

    @staticmethod
    def getZeroPoint(testGen, dtype, error_name=None):
        """Return a zero point appropriate for dtype.

        INT8/UINT8 use the user-supplied zero point (clamped to the type
        range) or a random in-range value.  For ERROR_IF zero-point tests
        on other types a deliberately non-zero value is produced; all
        remaining types must use a zero point of 0.
        """
        if dtype == DType.INT8:
            if testGen.args.zeropoint is not None:
                # Clamp the user supplied zero point into the INT8 range
                return max(-128, min(127, testGen.args.zeropoint))
            return testGen.randInt(-128, 128)
        if dtype == DType.UINT8:
            if testGen.args.zeropoint is not None:
                # Clamp the user supplied zero point into the UINT8 range
                return max(0, min(255, testGen.args.zeropoint))
            return testGen.randInt(0, 256)
        if error_name in (
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ):
            # Negative test - force a non-zero zero point
            zero_point = testGen.randInt(-128, 128)
            return zero_point if zero_point != 0 else 1
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Return [input_zp, output_zp] for a unary operator.

        The error_name (when it matches) selects which of the two zero
        points is forced to be non-zero for the ERROR_IF test.
        """
        input_err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        output_err = (
            error_name if error_name == ErrorIf.OutputZeroPointNotZero else None
        )
        # Keep the call order fixed (input first) so random draws are stable
        return [
            TosaQuantGen.getZeroPoint(testGen, dtype, input_err),
            TosaQuantGen.getZeroPoint(testGen, dtype, output_err),
        ]

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Return [input_zp, weight_zp] for a convolution style operator."""
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        input_err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        weight_err = (
            error_name if error_name == ErrorIf.WeightZeroPointNotZero else None
        )
        return [
            TosaQuantGen.getZeroPoint(testGen, dtypeList[0], input_err),
            TosaQuantGen.getZeroPoint(testGen, dtypeList[1], weight_err),
        ]

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Return [a_zp, b_zp] for MATMUL - both err together for ERROR_IF."""
        err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        return [
            TosaQuantGen.getZeroPoint(testGen, dtype, err),
            TosaQuantGen.getZeroPoint(testGen, dtype, err),
        ]

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into a (multiplier, shift) pair.

        Derived from computeMultiplierAndShiftTosaScale32.  scale32 selects
        a 32-bit (31 fractional bits) or 16-bit (15 fractional bits)
        multiplier.  The returned shift is guaranteed to be in [2, 62].
        """
        scaleBits = 31 if scale32 else 15

        mantissa, exponent = math.frexp(scaleFp)
        if scaleFp < 0.0:
            # frexp keeps the sign on the mantissa; fold it back out
            mantissa = -mantissa

        multiplier = round(mantissa * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            # Rounding overflowed into the next power of two - renormalize
            multiplier //= 2
            exponent += 1

        shift = scaleBits - exponent

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier //= 4
            shift += 2
        elif shift == 1:
            multiplier //= 2
            shift += 1
        elif shift == 63:
            multiplier *= 2
            shift -= 1

        assert multiplier <= (1 << scaleBits)
        assert 2 <= shift <= 62

        return multiplier, shift
152
153
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        # Stateless - every generator below is a static method
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Create identical shapes for every operand of a basic operator.

        opName: the operator's entry dict - "operands" holds the
            (placeholder count, const count) pair.
        rank: rank of the shape to create.
        error_name: optional ErrorIf name, used to build negative tests.
        Returns a list of shapes, one per operand.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    # Rank 1 can only mismatch upwards
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Create identical rank-4 (NHWC) shapes for every operand."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgGather(testGen, opName, rank, error_name=None):
        """Create [values (N,K,C), indices (N,W)] shapes for GATHER."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_shape = testGen.makeShape(rank)
        values_shape = testGen.constrictBatchSize(values_shape)

        N = values_shape[0]
        W = testGen.makeDimension()
        indices_shape = [N, W]

        shape_list = [values_shape, indices_shape]
        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Create [values_in (N,K,C), indices (N,W), input (N,W,C)] for SCATTER."""
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = testGen.randInt(W_min, W_max) if W_min < W_max else W_min

        input_shape = [N, W, C]

        shape_list = []
        shape_list.append(values_in_shape)
        shape_list.append([N, W])  # indices
        shape_list.append(input_shape)

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Create operand shapes where one randomly chosen input broadcasts
        (dimension 1) along a randomly chosen axis."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        fuzz_idx = testGen.randInt(0, rank)

        for i in range(pl + const):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    # Incompatible (non-1, unequal) dimension on the fuzz axis
                    shape_bcast[fuzz_idx] += 2
                else:
                    # Valid broadcast - set the fuzz axis dimension to 1
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (OHWI), bias (OC)] shapes for CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Create [ifm (NDHWC), filter (ODHWI), bias (OC)] shapes for CONV3D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension()

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (OHWI), bias (OC)] for TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (HWCM), bias (C*M)] for DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        # NOTE(review): assumes tensor_shape_range[1] >= 4 - a smaller
        # configured maximum would divide by zero here; confirm CLI limits
        filter_m = (
            testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        """Create two identical NHW shapes (H and W powers of two) for FFT2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        # NOTE(review): math.log(x, 2) can suffer float rounding on exact
        # powers of two; math.log2 would be safer - confirm before changing
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # Increment by 2 when the dimension is 1 (1+1=2 is a power of two)
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        """Create a single NHW shape (H and W powers of two) for RFFT2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        # NOTE(review): same math.log(x, 2) float-rounding caveat as tgFFT2d
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Create [input (N,IC), filter (OC,IC), bias (OC)] for FULLY_CONNECTED."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Create [a (N,H,C), b (N,C,W)] shapes for MATMUL."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Create shapes for CONCAT, adding 0-3 extra inputs; optionally
        making some inputs the wrong rank for negative tests."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                # Randomly either drop a dimension or append a new one
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Rework CONCAT shapes so constant inputs split the first shape
        along axis rather than repeating full-size tensors."""
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                # Last iteration - the remainder becomes the final input
                new_shapeList.append(shape.copy())

        return new_shapeList
624
625
626class TosaTensorValuesGen:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100627 """Tensor Value generators create the random data for each tensor in each test."""
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100628
    def __init__(self):
        # Stateless helper class - the generators are all static methods
        pass
631
    class TVGInfo:
        """Enhanced tensor values information including data gen dict."""

        def __init__(self, tensorList, dataGenDict):
            # Serializer tensors (placeholders/consts) created for the test
            self.tensorList = tensorList
            # JSON-style meta data describing how to generate the tensor
            # data, or None when the data was created directly
            self.dataGenDict = dataGenDict
638
    # Default high value for random numbers
    # (these are the largest finite magnitudes for each FP type, so derived
    # ranges avoid producing infinity)
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
        DType.FP8E4M3: 448,
        DType.FP8E5M2: 57344,
    }

    # Default lowest normal values for random numbers
    # NOTE(review): the FP8 entries (2^-9, 2^-16) are the minimum
    # subnormals for E4M3/E5M2, not the minimum normals - confirm intent
    TVG_FLOAT_LOW_VALUE = {
        DType.FP32: np.exp2(-126),
        DType.FP16: np.exp2(-14),
        DType.BF16: np.exp2(-126),
        DType.FP8E4M3: np.exp2(-9),
        DType.FP8E5M2: np.exp2(-16),
    }
656
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100657 @staticmethod
Jeremy Johnson30476252023-11-20 16:15:30 +0000658 def _get_data_range(testGen, dtype, highValueLookup, lowValueLookup=None):
659 # Return a tuple of (low,high) data range values for the given data
660 # type using a combination of per operator table limits, data limits
661 # and user supplied ranges for FP numbers
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000662 if dtype in highValueLookup:
Jeremy Johnson30476252023-11-20 16:15:30 +0000663 type_range = testGen.getDTypeRange(dtype, high_inclusive=True)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000664 high_val = highValueLookup[dtype]
Jeremy Johnson30476252023-11-20 16:15:30 +0000665 if lowValueLookup is not None and dtype in lowValueLookup:
666 low_val = lowValueLookup[dtype]
667 else:
668 low_val = -high_val
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000669 # Set the values to something that won't produce infinity whilst
Jeremy Johnson30476252023-11-20 16:15:30 +0000670 # respecting the default ranges if more/less than the low/high
671 # values
672 data_range = (
673 max(low_val, type_range[0]),
674 min(high_val, type_range[1]),
675 )
676 if data_range[0] > data_range[1]:
677 # Invalid data range from low to high created due to user
678 # constraints revert to using internal ranges as they are
679 # known to work
680 msg = f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
681 warnings.warn(msg)
682 data_range = (low_val, high_val)
683 return data_range
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000684 return None
685
    @staticmethod
    def tvgLazyGenDefault(
        testGen, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Create serializer tensors for a test plus data generation meta data.

        Two paths:
        * Fallback: for ERROR_IF tests, dtypes unsupported by compliance, or
          ops without "data_gen" info - creates the numpy data immediately
          and returns TVGInfo(tensor list, None).
        * Otherwise: builds a JSON-style meta data dict per tensor telling
          the data generator library how to create the data; the data itself
          is generated up front only when not in lazy mode, or for the
          dot-product bias tensor (needed for the compliance KS adjustment).

        argsDict keys read: "p_count"/"c_count", "data_range",
        "data_range_list", "fixed_data", "dg_type", "s", "ks", "acc_type",
        "kernel", "axis".  May write argsDict["ksb"].
        """
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
        if "p_count" in argsDict:
            # Override for operators like CONCAT
            pCount = argsDict["p_count"]
            cCount = argsDict["c_count"]
        assert pCount + cCount == len(
            shapeList
        ), "Placeholders & Constant tensors must match shapes list"

        tens_ser_list = []

        if (
            error_name is not None
            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
            or "data_gen" not in testGen.TOSA_OP_LIST[opName]
        ):
            # Fall back to internal data gen when dealing with unsupported types or ops
            data_range = argsDict["data_range"] if "data_range" in argsDict else None
            for idx, info in enumerate(zip(shapeList, dtypeList)):
                roundMode = False
                shape, dtype = info
                if "data_range_list" in argsDict:
                    # Per-tensor range (and optional rounding) overrides
                    data_range = argsDict["data_range_list"][idx]["range"]
                    roundMode = (
                        "round" in argsDict["data_range_list"][idx]
                        and argsDict["data_range_list"][idx]["round"] is True
                    )
                if data_range is not None and dtype not in (
                    DType.FP16,
                    DType.FP32,
                    DType.BF16,
                    DType.FP8E4M3,
                    DType.FP8E5M2,
                ):
                    # Change from inclusive to exclusive range
                    # NOTE(review): with a shared "data_range" and several
                    # integer tensors this re-adjusts the already adjusted
                    # range each iteration (high bound grows by 1 per
                    # tensor) - looks unintended, confirm
                    data_range = (data_range[0], data_range[1] + 1)
                # Ignore lazy data gen option and create data array using any range limits

                if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                    arr = np.int64(argsDict["fixed_data"][idx])
                else:
                    arr = testGen.getRandTensor(shape, dtype, data_range)
                if roundMode:
                    arr = np.round(arr)
                if idx < pCount:
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
        dg_type = argsDict["dg_type"]
        tens_data = {
            "version": "0.1",
            "tensors": {},
        }
        dg_tens_meta = tens_data["tensors"]
        for idx, shape in enumerate(shapeList):

            tens_meta = {}
            # Fixed data (if supplied for this position) wins over dg_type
            if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                tens_meta["generator"] = gtu.DataGenType(
                    gtu.DataGenType.FIXED_DATA
                ).name
            else:
                tens_meta["generator"] = gtu.DataGenType(dg_type).name

            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()

            if idx < pCount:
                tens_meta["input_type"] = "VARIABLE"
            else:
                tens_meta["input_type"] = "CONSTANT"

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                if (
                    tens_meta["generator"]
                    == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name
                ):
                    info["data"] = [int(i) for i in argsDict["fixed_data"][idx]]
                    tens_meta["fixed_data_info"] = info
                else:
                    # TODO - generate seed for this generator based on test
                    info["rng_seed"] = 42

                    data_range = None
                    if "data_range_list" in argsDict:
                        data_range = argsDict["data_range_list"][idx]["range"]
                        if "round" in argsDict["data_range_list"][idx]:
                            info["round"] = argsDict["data_range_list"][idx]["round"]
                    elif "data_range" in argsDict:
                        data_range = argsDict["data_range"]

                    if data_range is None:
                        # No explicit range - fall back to the dtype's range
                        data_range = testGen.getDTypeRange(
                            dtypeList[idx], high_inclusive=True
                        )
                    info["range"] = [str(v) for v in data_range]
                    tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = int(argsDict["ks"])
                if "acc_type" in argsDict:
                    # Convert type number into JSON name
                    info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
                        "json"
                    ]
                if "kernel" in argsDict:
                    info["kernel"] = [int(k) for k in argsDict["kernel"]]
                if "axis" in argsDict:
                    info["axis"] = int(argsDict["axis"])
                tens_meta["dot_product_info"] = info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"

            # Using the finished generate config meta data - generate the data if
            # needed and assign a tensor name from the serializer

            # Need to generate data when not lazy or for the bias tensor as we need
            # to work out if the bias data is non-zero for compliance
            if not testGen.args.lazy_data_gen or (
                idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
            ):
                # Give this tensor a temporary name until we get one from the serializer
                temp_name = f"placeholder_{idx}"
                dg_tens_meta[temp_name] = tens_meta
                # Create data now using the temporary name to access meta details
                data = testGen.dgl.get_tensor_data(temp_name, tens_data)
                if tens_meta["data_type"] == "SHAPE":
                    # Tensor type SHAPE and Numpy file type must be the same
                    data = np.int64(data)
                # Remove the item as we will give it the correct name later
                del dg_tens_meta[temp_name]

            if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
                # The KS value used by compliance verification is altered when the
                # bias data is non-zero
                if max(abs(data)) > 0.0:
                    argsDict["ksb"] = argsDict["ks"] + 1

            if testGen.args.lazy_data_gen:
                # Lazy mode - the data generator library will create the
                # data later from the meta data; serialize without it
                data = None

            if tens_meta["input_type"] == "VARIABLE":
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
            else:
                tens = testGen.ser.addConst(shape, dtypeList[idx], data)

            tens_ser_list.append(tens)
            # Add the meta data to the list using the serializer tensor name
            dg_tens_meta[tens.name] = tens_meta

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
851
852 @staticmethod
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000853 def tvgNegate(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100854 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000855 # Integer test
856 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100857 pCount, cCount = op["operands"]
858 assert (
859 pCount == 1 and cCount == 0
860 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100861 # Must create tensors with values within accumulator (int32) negatable
862 # range
863 max_val = (1 << 31) - 1
864 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100865 arr = np.int32(
866 testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
867 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000868 tens_ser_list = []
869 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100870 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
871 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000872 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100873 else:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000874 # ERROR_IF or floating point test
875 return TosaTensorValuesGen.tvgLazyGenDefault(
876 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100877 )
878
    # Set the ADD/SUB data range to half the largest value to avoid infinities
    # (the sum of two operands drawn from [-max/2, max/2] stays finite).
    TVG_FLOAT_HIGH_VALUE_ADDSUB = {
        DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
        DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
        DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
    }
885
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100886 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000887 def tvgAddSub(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Won Jeon74342e52024-01-09 00:34:40 +0000888 if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000889 # Make sure the integer operation does not cause value saturation - where
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100890 # the number wraps due to limited number of bits to store the answer
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000891 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100892 pCount, cCount = op["operands"]
893 assert (
894 pCount == 2 and cCount == 0
895 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000896 tens_ser_list = []
Won Jeon74342e52024-01-09 00:34:40 +0000897 add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
898 data_range = testGen.args.tensor_shape_range
899 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0], data_range)
900 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1], data_range)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100901 if add:
902 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
903 else:
904 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
905
906 # Work out the saturation limits
907 max_i32 = (1 << 31) - 1
908 min_i32 = -(1 << 31)
909 max_arr = np.full(shapeList[1], max_i32)
910 min_arr = np.full(shapeList[1], min_i32)
911
912 # Find how much values exceed the maximum/minimums
913 sat_max_arr = np.maximum(res_arr - max_arr, 0)
914 sat_min_arr = np.minimum(res_arr - min_arr, 0)
915
916 if not add:
917 # Swap saturation values and negate values as we need to perform opposite operations
918 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
919
920 # Create new array of unsaturated values by clipping values as needed
921 b_unsat_arr = b_arr
922 if (sat_max_arr != 0).any():
923 # Clip values that cause saturation
924 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
925 # Reduce axes in unsaturated tensor to match original tensor
926 for axis, dim in enumerate(b_arr.shape):
927 if dim != b_unsat_arr.shape[axis]:
928 assert (
929 dim == 1
930 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
931 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
932
933 if (sat_min_arr != 0).any():
934 # Clip values that cause saturation
935 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
936 # Reduce axes in unsaturated tensor to match original tensor
937 for axis, dim in enumerate(b_arr.shape):
938 if dim != b_unsat_arr.shape[axis]:
939 assert (
940 dim == 1
941 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
942 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
943
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000944 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100945 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
946 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000947 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100948 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
949 )
950
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000951 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100952 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000953 # ERROR_IF or floating point test
954 data_range = TosaTensorValuesGen._get_data_range(
955 testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
956 )
957 if data_range:
958 argsDict["data_range"] = data_range
959
960 return TosaTensorValuesGen.tvgLazyGenDefault(
961 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100962 )
963
964 @staticmethod
965 def tvgCondIfWhileLoop(
Jeremy Johnson587cc842024-02-08 11:45:44 +0000966 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100967 ):
968 if dtypeList[0] in (
969 DType.INT32,
970 DType.INT16,
971 DType.INT8,
972 ):
973 # Limit input tensors with cond_if_binary or while_loop to stop
974 # saturation of add/sub ops with int32 and keep all logical shift
975 # values between 0 to 31 for int16 or int8
Jeremy Johnson587cc842024-02-08 11:45:44 +0000976 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100977 pCount, cCount = op["operands"]
978 pRemain = pCount
Jeremy Johnson587cc842024-02-08 11:45:44 +0000979 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100980 for idx, shape in enumerate(shapeList[:]):
981 if dtypeList[0] == DType.INT32:
982 arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
983 else:
984 arr = np.int32(
985 testGen.rng.integers(low=0, high=32, size=shapeList[idx])
986 )
987 if pRemain > 0:
Jeremy Johnson587cc842024-02-08 11:45:44 +0000988 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100989 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
990 )
991 pRemain -= 1
992 else:
Jeremy Johnson587cc842024-02-08 11:45:44 +0000993 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100994 testGen.ser.addConst(shape, dtypeList[idx], arr)
995 )
996
Jeremy Johnson587cc842024-02-08 11:45:44 +0000997 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100998 else:
Jeremy Johnson587cc842024-02-08 11:45:44 +0000999 return TosaTensorValuesGen.tvgLazyGenDefault(
1000 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001001 )
1002
1003 @staticmethod
1004 def tvgArithmeticRightShift(
Jeremy Johnson587cc842024-02-08 11:45:44 +00001005 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001006 ):
Jeremy Johnson587cc842024-02-08 11:45:44 +00001007 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001008 pCount, cCount = op["operands"]
1009 # Force value of operand[1] to be within [0, num_bits]
1010 assert (
1011 pCount == 2 and cCount == 0
1012 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
1013
Jeremy Johnson587cc842024-02-08 11:45:44 +00001014 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001015 for idx, shape in enumerate(shapeList[:]):
1016 if idx == 1:
1017 if dtypeList[idx] == DType.INT8:
1018 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
1019 elif dtypeList[idx] == DType.INT16:
1020 arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
1021 elif dtypeList[idx] == DType.INT32:
1022 arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
1023 elif error_name == ErrorIf.WrongInputType:
1024 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
1025 else:
1026 raise Exception("OpArithmeticRightShift: invalid input dtype")
1027 else:
1028 arr = testGen.getRandTensor(shape, dtypeList[idx])
Jeremy Johnson587cc842024-02-08 11:45:44 +00001029 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001030
Jeremy Johnson587cc842024-02-08 11:45:44 +00001031 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001032
1033 @staticmethod
Jeremy Johnson587cc842024-02-08 11:45:44 +00001034 def tvgReshape(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Won Jeon64e4bfe2024-01-18 06:31:55 +00001035 dtypeList[1] = DType.SHAPE
1036 shapeList[1] = [len(argsDict["new_shape"])]
1037 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1038 argsDict["fixed_data"] = [None, argsDict["new_shape"]]
1039
1040 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson587cc842024-02-08 11:45:44 +00001041 testGen, opName, dtypeList, shapeList, argsDict, error_name
Won Jeon64e4bfe2024-01-18 06:31:55 +00001042 )
1043
1044 @staticmethod
Jeremy Johnson587cc842024-02-08 11:45:44 +00001045 def tvgPad(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Tai Lye095da72024-01-25 22:00:18 +00001046 # argsDict["pad"] is 2D array, need to flatten it to get list of values
1047 pad_values = argsDict["pad"].flatten()
1048 dtypeList[1] = DType.SHAPE
1049 shapeList[1] = [len(pad_values)]
1050 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1051 argsDict["fixed_data"] = [None, pad_values]
1052
1053 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson587cc842024-02-08 11:45:44 +00001054 testGen, opName, dtypeList, shapeList, argsDict, error_name
Tai Lye095da72024-01-25 22:00:18 +00001055 )
1056
1057 @staticmethod
Jeremy Johnson587cc842024-02-08 11:45:44 +00001058 def tvgSlice(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
TatWai Chongf15bad82024-01-31 21:33:27 -08001059 dtypeList[1] = DType.SHAPE
1060 shapeList[1] = [len(argsDict["start"])]
1061 dtypeList[2] = DType.SHAPE
1062 shapeList[2] = [len(argsDict["size"])]
1063 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1064 argsDict["fixed_data"] = [None, argsDict["start"], argsDict["size"]]
1065
1066 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson587cc842024-02-08 11:45:44 +00001067 testGen, opName, dtypeList, shapeList, argsDict, error_name
TatWai Chongf15bad82024-01-31 21:33:27 -08001068 )
1069
1070 @staticmethod
Jeremy Johnson587cc842024-02-08 11:45:44 +00001071 def tvgTile(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Won Jeon64e4bfe2024-01-18 06:31:55 +00001072 dtypeList[1] = DType.SHAPE
1073 shapeList[1] = [len(argsDict["multiples"])]
1074 argsDict["fixed_data"] = [None, argsDict["multiples"]]
1075
1076 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson587cc842024-02-08 11:45:44 +00001077 testGen, opName, dtypeList, shapeList, argsDict, error_name
Won Jeon64e4bfe2024-01-18 06:31:55 +00001078 )
1079
    @staticmethod
    def tvgSelect(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Set up SELECT inputs: force input 0 (the condition) to BOOL and
        delegate data generation to the lazy default generator."""
        # Set datatype of condition tensor to boolean
        dtypeList[0] = DType.BOOL

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, opName, dtypeList, shapeList, argsDict, error_name
        )
1088
1089 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001090 def tvgIntDiv(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001091 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001092 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001093 pCount, cCount = op["operands"]
1094 assert (
1095 pCount == 2 and cCount == 0
1096 ), "Op.INTDIV must have 2 placeholders, 0 consts"
1097
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001098 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001099
1100 # Two invalid cases for Op.INTDIV:
1101 # 1. divisor == 0
1102 # 2. dividend == -(1<<31) and divisor == -1
1103 while True:
1104 dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1105 divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
1106
1107 if (divisor_arr == 0).any():
1108 continue
1109
1110 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
1111 continue
1112
1113 break
1114
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001115 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001116 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1117 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001118 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001119 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1120 )
1121
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001122 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001123 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001124 return TosaTensorValuesGen.tvgLazyGenDefault(
1125 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001126 )
1127
    # Set the MUL data range to the square root of the largest value
    # to avoid infinities
    # (the product of two operands up to sqrt(max) cannot exceed max).
    TVG_FLOAT_HIGH_VALUE_MUL = {
        DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
1135
    @staticmethod
    def tvgMul(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create input values for MUL tests.

        Floating point and ERROR_IF tests use the lazy default generator
        with a sqrt(max) data range.  Integer/SHAPE tests generate both
        operands directly, halving the values until the (shifted) product
        fits in int32.
        """
        if error_name is not None or dtypeList[0] in (
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ):
            # ERROR_IF or floating point test
            data_range = TosaTensorValuesGen._get_data_range(
                testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
            )
            if data_range:
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
        else:
            # Integer test
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            tens_ser_list = []

            # Make sure multiply result in int32 range
            # MUL_SHAPE has no shift attribute
            if dtypeList[0] == DType.SHAPE:
                shift = 0
            else:
                shift = argsDict["shift"]
            if dtypeList[0] == DType.INT8:
                num_bits = 8
            elif dtypeList[0] == DType.INT16:
                num_bits = 16
            elif dtypeList[0] in (DType.INT32, DType.SHAPE):
                num_bits = 32
            elif error_name == ErrorIf.WrongInputType:
                num_bits = 8
            else:
                raise Exception("OpMul: invalid input dtype")

            # NOTE(review): this loop regenerates BOTH operands on every
            # iteration, so only the low/high range chosen for the LAST idx
            # is effective for both tensors - looks like it was meant to
            # pick a per-operand range; verify against git history before
            # changing, as altering the rng call sequence changes the
            # generated test data.
            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[idx] == DType.SHAPE:
                    low = testGen.args.tensor_shape_range[0]
                    high = testGen.args.tensor_shape_range[1]
                else:
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                a_arr = np.int32(
                    testGen.rng.integers(low=low, high=high, size=shapeList[0])
                )
                b_arr = np.int32(
                    testGen.rng.integers(low=low, high=high, size=shapeList[1])
                )

            # Halve both operands until the rounded/shifted product fits in
            # int32 everywhere (i is just an iteration counter)
            i = 0
            while True:

                a_arr_64 = a_arr.astype(np.int64)
                b_arr_64 = b_arr.astype(np.int64)

                if shift > 0:
                    rounding = 1 << (shift - 1)
                    result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                else:
                    result_arr = a_arr_64 * b_arr_64

                if (result_arr > -(2**31)).all() and (
                    result_arr <= ((2**31) - 1)
                ).all():
                    break

                i = i + 1
                a_arr = a_arr // 2
                b_arr = b_arr // 2

            if dtypeList[0] == DType.SHAPE:
                # Tensor type SHAPE is serialized as 64-bit values
                tens_ser_list.append(
                    testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
                )
                tens_ser_list.append(
                    testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
                )
            else:
                tens_ser_list.append(
                    testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                tens_ser_list.append(
                    testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001231
1232 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001233 def tvgConcat(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001234 count = len(shapeList) - testGen.args.num_const_inputs_concat
1235 if count < 1:
1236 count = 1
1237 if testGen.args.num_const_inputs_concat == 0:
1238 count = len(shapeList)
1239
Won Jeon74342e52024-01-09 00:34:40 +00001240 op = testGen.TOSA_OP_LIST[opName]
1241 if op["op"] == Op.CONCAT_SHAPE:
1242 # Set the axis to 0
1243 shapeList = TosaTensorGen.tgConcatConstInput(
1244 testGen, shapeList, 0, error_name
1245 )
1246 else:
1247 shapeList = TosaTensorGen.tgConcatConstInput(
1248 testGen, shapeList, argsDict["axis"], error_name
1249 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001250
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001251 # Override default pCount/cCount for operator
1252 argsDict["p_count"] = count
1253 argsDict["c_count"] = len(shapeList) - count
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001254
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001255 return TosaTensorValuesGen.tvgLazyGenDefault(
1256 testGen, opName, dtypeList, shapeList, argsDict, error_name
1257 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001258
1259 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001260 def tvgLogicalShift(
1261 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1262 ):
1263 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001264 pCount, cCount = op["operands"]
1265 assert (
1266 pCount == 2 and cCount == 0
1267 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
1268 values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1269 shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001270 tens_ser_list = []
1271 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001272 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1273 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001274 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001275 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1276 )
1277
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001278 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001279
    @staticmethod
    def tvgEqual(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create input values for EQUAL tests.

        Integer tests generate random operands and then force a number of
        matching values so the result is not all-false; floating point and
        ERROR_IF tests use the lazy default generator.
        """
        if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
            # Integer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"

            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])

            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        testGen.rng.integers(
                            0, max(shapeList[0][axis], shapeList[1][axis])
                        )
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                # Copy b's value into a so these positions compare equal
                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
1326
    @staticmethod
    def tvgReduceSum(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create input values for REDUCE_SUM tests.

        INT32 tests restrict values so no axis-sum can overflow int32.
        Float (non-error, non-dot-product-compliance) tests restrict the
        data range so no axis-sum can reach infinity; both float paths end
        in the lazy default generator.
        """
        dtype = dtypeList[0]
        if dtype == DType.INT32:
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
            # Limit values so that the sum cannot exceed the range of an int32 during
            # summation of any axis
            range_val = int((1 << 31) / max(shapeList[0]))
            values_arr = np.int32(
                testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
            )
            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or dot product floating point test
            if (
                error_name is None
                and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
            ):
                # Limit ranges for (non error & non compliance) tests by using
                # values that can be summed on any axis to not hit infinity
                highval_lookup = {
                    dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
                    / max(shapeList[0])
                }
                data_range = TosaTensorValuesGen._get_data_range(
                    testGen, dtype, highval_lookup
                )
                assert data_range is not None
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
1368
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001369 @staticmethod
1370 def tvgReduceProduct(
1371 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1372 ):
1373 dtype = dtypeList[0]
1374 if error_name is None:
1375 # Limit ranges for (non error) tests by using
1376 # values that can be multiplied on any axis to not hit infinity
1377 highval_lookup = {
1378 dtype: math.pow(
1379 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
1380 1 / max(shapeList[0]),
1381 )
1382 }
1383 data_range = TosaTensorValuesGen._get_data_range(
1384 testGen, dtype, highval_lookup
1385 )
1386 assert data_range is not None
1387 argsDict["data_range"] = data_range
1388
1389 return TosaTensorValuesGen.tvgLazyGenDefault(
1390 testGen, opName, dtypeList, shapeList, argsDict, error_name
1391 )
1392
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001393 @staticmethod
1394 def tvgResize(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1395 data_range = TosaTensorValuesGen._get_data_range(
1396 testGen,
1397 dtypeList[0],
1398 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1399 )
1400 if data_range:
1401 argsDict["data_range"] = data_range
1402 # Needed for compliance
1403 argsDict["max_abs_value"] = data_range[1]
1404
Tai Lyc5c2a7e2024-02-22 23:26:28 +00001405 scale_values = argsDict["scale"]
1406 offset_values = argsDict["offset"]
1407 border_values = argsDict["border"]
1408 dtypeList[1] = DType.SHAPE
1409 dtypeList[2] = DType.SHAPE
1410 dtypeList[3] = DType.SHAPE
1411 shapeList[1] = [len(scale_values)]
1412 shapeList[2] = [len(offset_values)]
1413 shapeList[3] = [len(border_values)]
1414 argsDict["fixed_data"] = [None, scale_values, offset_values, border_values]
1415
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001416 return TosaTensorValuesGen.tvgLazyGenDefault(
1417 testGen, opName, dtypeList, shapeList, argsDict, error_name
1418 )
1419
Jeremy Johnson30476252023-11-20 16:15:30 +00001420 # Set the POW exponent high data range
1421 TVG_FLOAT_HIGH_VALUE_POW_EXP = {
1422 DType.FP32: 10.0,
1423 DType.FP16: 10.0,
1424 DType.BF16: 10.0,
1425 }
1426 # POW highest base value (within a safe margin of error) that can be raised
1427 # to +ve exponent that doesn't become Infinity
1428 TVG_FLOAT_HIGH_VALUE_POW_BASE = {
1429 DType.FP32: math.floor(
1430 math.pow(
1431 TVG_FLOAT_HIGH_VALUE[DType.FP32],
1432 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1433 )
1434 ),
1435 DType.FP16: math.floor(
1436 math.pow(
1437 TVG_FLOAT_HIGH_VALUE[DType.FP16],
1438 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1439 )
1440 ),
1441 DType.BF16: math.floor(
1442 math.pow(
1443 TVG_FLOAT_HIGH_VALUE[DType.BF16],
1444 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1445 )
1446 ),
1447 }
1448 # POW lowest base value (within a safe margin of error) that can be raised
1449 # to -ve exponent that doesn't become Infinity
1450 TVG_FLOAT_LOW_VALUE_POW_BASE = {
1451 DType.FP32: math.ceil(
1452 math.pow(
1453 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
1454 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1455 )
1456 * 1000
1457 )
1458 / 1000,
1459 DType.FP16: math.ceil(
1460 math.pow(
1461 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
1462 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1463 )
1464 * 1000
1465 )
1466 / 1000,
1467 DType.BF16: math.ceil(
1468 math.pow(
1469 1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
1470 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1471 )
1472 * 1000
1473 )
1474 / 1000,
1475 }
1476
    @staticmethod
    def tvgPow(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create input values for POW tests.

        Uses per-operand data ranges selected by the test set number in
        argsDict["s"]:
          0 - positive base, fractional exponent
          1 - positive base, integer (rounded) exponent
          2 - negative base, integer (rounded) exponent
        All paths delegate the actual generation to the lazy default
        generator via argsDict["data_range_list"].
        """
        if error_name is not None:
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
        dtype = dtypeList[0]
        # Different ranges for POW
        test_set = argsDict["s"]
        if test_set == 0:
            # Positive base with fractional exponent
            base_range = TosaTensorValuesGen._get_data_range(
                testGen,
                dtype,
                TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
                TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
            )
            exp_range = TosaTensorValuesGen._get_data_range(
                testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
            )
            exp_round = False
        else:
            # Integer exponent
            exp_range = TosaTensorValuesGen._get_data_range(
                testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
            )
            exp_round = True
            if test_set == 1:
                # Positive base
                base_range = TosaTensorValuesGen._get_data_range(
                    testGen,
                    dtype,
                    TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
                    TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
                )
            else:
                assert test_set == 2
                # Negative base
                # Supply new look up tables with negative values
                base_range = TosaTensorValuesGen._get_data_range(
                    testGen,
                    dtype,
                    {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
                    {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
                )

        # One range entry per input: [0] = base, [1] = exponent
        data_range_list = (
            {
                "range": base_range,
            },
            {
                "range": exp_range,
                "round": exp_round,
            },
        )
        argsDict["data_range_list"] = data_range_list
        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, opName, dtypeList, shapeList, argsDict, error_name
        )
1536
1537 @staticmethod
1538 def tvgLogRsqrt(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1539 # LOG & RSQRT data range from lowest expressible positive number to
1540 # largest to avoid NaNs
1541 data_range = TosaTensorValuesGen._get_data_range(
1542 testGen,
1543 dtypeList[0],
1544 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1545 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1546 )
1547 if data_range:
1548 argsDict["data_range"] = data_range
1549
1550 return TosaTensorValuesGen.tvgLazyGenDefault(
1551 testGen, opName, dtypeList, shapeList, argsDict, error_name
1552 )
1553
    # Set the EXP data range to the log of the largest to smallest values
    # to avoid infinities or making the result zero
    # Per-dtype maximum EXP input: log(max value) so exp(input) stays finite
    TVG_FLOAT_HIGH_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
    # Per-dtype minimum EXP input: log(lowest positive value) so exp(input)
    # does not collapse to zero
    TVG_FLOAT_LOW_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
    }
1566
1567 @staticmethod
1568 def tvgExp(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1569 data_range = TosaTensorValuesGen._get_data_range(
1570 testGen,
1571 dtypeList[0],
1572 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1573 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1574 )
1575 if data_range:
1576 argsDict["data_range"] = data_range
1577
1578 return TosaTensorValuesGen.tvgLazyGenDefault(
1579 testGen, opName, dtypeList, shapeList, argsDict, error_name
1580 )
1581
1582 @staticmethod
1583 def tvgFullyConnected(
1584 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1585 ):
1586 dtype = dtypeList[0]
1587 if (
1588 error_name is None
1589 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
Jeremy Johnson718f3472023-11-30 14:18:19 +00001590 and dtype in (DType.BF16,)
Jeremy Johnson30476252023-11-20 16:15:30 +00001591 ):
Jeremy Johnson718f3472023-11-30 14:18:19 +00001592 # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
Jeremy Johnson30476252023-11-20 16:15:30 +00001593 # Limit ranges for (non error & non compliance) FP tests by using
1594 # values that can be multiplied on any axis to not hit infinity/NaN
1595 IC = shapeList[0][1]
1596 highval_lookup = {
1597 dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
1598 }
1599 data_range = TosaTensorValuesGen._get_data_range(
1600 testGen, dtype, highval_lookup
1601 )
1602 assert data_range is not None
1603 argsDict["data_range"] = data_range
1604
1605 return TosaTensorValuesGen.tvgLazyGenDefault(
1606 testGen, opName, dtypeList, shapeList, argsDict, error_name
1607 )
1608
Jeremy Johnson708da822023-11-15 16:25:45 +00001609 @staticmethod
1610 def tvgCast(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1611 in_dtype = dtypeList[0]
1612 out_dtype = argsDict["out_type"]
1613 # Create look up to limit input tensor to output type maximums to avoid
1614 # FP infinities and saturation of integers
1615 out_range = testGen.getDTypeRange(out_dtype, high_inclusive=True)
1616 highval_lookup = {in_dtype: out_range[1]}
1617 data_range = TosaTensorValuesGen._get_data_range(
1618 testGen,
1619 in_dtype,
1620 highval_lookup,
1621 )
1622
1623 assert data_range is not None
1624 argsDict["data_range"] = data_range
1625
1626 return TosaTensorValuesGen.tvgLazyGenDefault(
1627 testGen, opName, dtypeList, shapeList, argsDict, error_name
1628 )
1629
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001630 @staticmethod
1631 def tvgGather(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1632 K = shapeList[0][1]
1633
1634 # Fix the type of the indices tensor
1635 dtypeList[1] = DType.INT32
1636
1637 dtype = dtypeList[0]
1638 if not gtu.dtypeIsSupportedByCompliance(dtype):
1639 # Test unsupported by data generator
1640 op = testGen.TOSA_OP_LIST[opName]
1641 pCount, cCount = op["operands"]
1642 assert (
1643 pCount == 2 and cCount == 0
1644 ), "Op.GATHER must have 2 placeholders, 0 consts"
1645
1646 tens_ser_list = []
1647 for idx, shape in enumerate(shapeList):
1648 dtype = dtypeList[idx]
1649 if idx != 1:
1650 arr = testGen.getRandTensor(shape, dtype)
1651 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1652 else:
1653 # Limit data range of indices tensor upto K (exclusive)
1654 arr = testGen.getRandTensor(shape, dtype, (0, K))
1655 # To match old functionality - create indices as CONST
1656 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1657
1658 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1659
1660 else:
1661 # ERROR_IF or floating point test
1662 # Use inclusive values upto index K for indices tensor
1663 data_range_list = (
1664 {"range": None},
1665 {"range": (0, K - 1)},
1666 )
1667 argsDict["data_range_list"] = data_range_list
1668
1669 return TosaTensorValuesGen.tvgLazyGenDefault(
1670 testGen, opName, dtypeList, shapeList, argsDict, error_name
1671 )
1672
1673 @staticmethod
1674 def tvgScatter(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1675 K = shapeList[0][1]
1676 W = shapeList[2][1]
1677
1678 # Work out an indices tensor here with data that doesn't exceed the
1679 # dimension K of the values_in tensor and does NOT repeat the same K
1680 # location as needed by the spec:
1681 # "It is not permitted to repeat the same output index within a single
1682 # SCATTER operation and so each output index occurs at most once."
1683 assert K >= W, "Op.SCATTER W must be smaller or equal to K"
1684
1685 # Fix the type of the indices tensor
1686 dtypeList[1] = DType.INT32
1687
1688 dtype = dtypeList[0]
1689 if not gtu.dtypeIsSupportedByCompliance(dtype):
1690 # Test unsupported by data generator
1691 op = testGen.TOSA_OP_LIST[opName]
1692 pCount, cCount = op["operands"]
1693 assert (
1694 pCount == 3 and cCount == 0
1695 ), "Op.SCATTER must have 3 placeholders, 0 consts"
1696
1697 tens_ser_list = []
1698 for idx, shape in enumerate(shapeList):
1699 dtype = dtypeList[idx]
1700 if idx != 1:
1701 arr = testGen.getRandTensor(shape, dtype)
1702 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1703 else:
1704 # Create the indices array
1705 assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
1706 arr = []
1707 for n in range(shape[0]):
1708 # Get a shuffled list of output indices (0 to K-1) and
1709 # limit length to W
1710 arr.append(testGen.rng.permutation(K)[:W])
1711 indices_arr = np.array(arr, dtype=np.int32) # (N, W)
1712 # To match old functionality - create indices as CONST
1713 tens_ser_list.append(
1714 testGen.ser.addConst(shape, dtype, indices_arr)
1715 )
1716
1717 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1718
1719 else:
1720 # ERROR_IF or floating point test
1721 # Use inclusive values upto index K for indices tensor
1722 data_range_list = (
1723 {"range": None},
1724 {"range": (0, K - 1)},
1725 {"range": None},
1726 )
1727 argsDict["data_range_list"] = data_range_list
1728
1729 return TosaTensorValuesGen.tvgLazyGenDefault(
1730 testGen, opName, dtypeList, shapeList, argsDict, error_name
1731 )
1732
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001733
1734class TosaArgGen:
1735 """Argument generators create exhaustive or random lists of attributes for
1736 operators that take attributes or other parameters.
1737
1738 The return value is a list of (descriptive_name, [arglist]) tuples where
1739 the descriptive_name is appended to the test name and the arglist is expanded
1740 as arguments to the operator build function.
1741 """
1742
    def __init__(self):
        # No instance state - all argument generators are static methods
        pass
1745
    @staticmethod
    def _add_data_generators(testGen, opName, dtype, arg_list, error_name):
        """Add extra tests for each type of data generator for this op.

        Expands the supplied list of (arg_str, args_dict) tuples so there is
        an entry per applicable data generator type, and per test set where
        the generator uses multiple sets.  Returns the expanded list.
        """
        if (
            error_name is None
            and "data_gen" in testGen.TOSA_OP_LIST[opName]
            and gtu.dtypeIsSupportedByCompliance(dtype)
        ):
            # Use the op's listed generators for floating point or integer
            if dtype in [
                DType.FP16,
                DType.FP32,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            ]:
                dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
            else:
                dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
        else:
            # Error test or No data generator types listed - assume random
            dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)

        # Expand arg list with other data generator types
        new_arg_list = []
        for dg_type in dataGenTypesList:
            for arg_str, args_dict in arg_list:
                # NOTE: mutates the caller-supplied dict in place
                args_dict["dg_type"] = dg_type
                if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                    if error_name is None:
                        # Ops may request multiple value sets (e.g. POW)
                        num_test_sets = (
                            args_dict["num_test_sets"]
                            if "num_test_sets" in args_dict
                            else 0
                        )
                    else:
                        num_test_sets = 0

                elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                    # Extra tests for each dot product test set
                    dot_products = args_dict["dot_products"]
                    if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
                        # Too few calculations for a meaningful compliance test
                        shape_info = (
                            " ({})".format(testGen.shapeStr(args_dict["shape"]))
                            if "shape" in args_dict
                            else ""
                        )
                        print(
                            f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
                        )
                        continue
                    # KS and acc_type is required by all dot product generators
                    assert "ks" in args_dict
                    assert "acc_type" in args_dict

                    num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS

                # NOTE(review): any other dg_type leaves num_test_sets holding
                # the value from a previous iteration (or unbound on the first
                # pass) - confirm only the two types above are ever listed
                if num_test_sets > 0:
                    for s in range(0, num_test_sets):
                        # Suffix the test name with the set number
                        new_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
                        new_args_dict = args_dict.copy()
                        new_args_dict["s"] = s
                        new_arg_list.append((new_arg_str, new_args_dict))
                else:
                    # Default is a single test
                    new_arg_list.append((arg_str, args_dict))

        return new_arg_list
1813
1814 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001815 def agNone(testGen, opName, shapeList, dtype, error_name=None):
1816 """A trivial argument generator for operators that don't take any
1817 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001818 arg_list = TosaArgGen._add_data_generators(
1819 testGen,
1820 opName,
1821 dtype,
1822 [("", {})],
1823 error_name,
1824 )
1825 # Return list of tuples: (arg_str, args_dict)
1826 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001827
1828 @staticmethod
Jeremy Johnson30476252023-11-20 16:15:30 +00001829 def agPow(testGen, opName, shapeList, dtype, error_name=None):
1830 """Pow operator needs different test sets to cover random numbers
1831 without creating NaNs or Infs"""
1832 arg_list = TosaArgGen._add_data_generators(
1833 testGen,
1834 opName,
1835 dtype,
1836 [("", {"num_test_sets": 3})],
1837 error_name,
1838 )
1839 # Return list of tuples: (arg_str, args_dict)
1840 return arg_list
1841
1842 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001843 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
1844 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001845 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001846 shape = shapeList[0]
1847
1848 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001849 # Set too small axis
1850 axes = [testGen.rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001851 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001852 # Set too large axis
1853 axes = [testGen.rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001854 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001855 # Create tests for each dimension
1856 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001857
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001858 opid = testGen.TOSA_OP_LIST[opName]["op"]
1859
1860 for a in axes:
1861 args_dict = {"axis": int(a)}
1862 if opid == Op.REDUCE_SUM:
1863 args_dict["dot_products"] = gtu.product(shape)
1864 args_dict["shape"] = shape
1865 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1866 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1867
1868 arg_list.append(("axis{}".format(a), args_dict))
1869
1870 arg_list = TosaArgGen._add_data_generators(
1871 testGen,
1872 opName,
1873 dtype,
1874 arg_list,
1875 error_name,
1876 )
1877 # Return list of tuples: (arg_str, args_dict)
1878 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001879
1880 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001881 def _calculate_sparsity(num_tests, sparsity_factor):
1882 sparsity = num_tests // sparsity_factor + 1
1883 # If there are only a small number of tests, just select them all
1884 if sparsity < 13:
1885 sparsity = 1
1886 # To get a variety of parameter combinations sparsity should not be a
1887 # multiple of 2, 3 or 5
1888 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1889 sparsity += 1
1890 return sparsity
1891
    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        """Create stride/padding/dilation arguments for convolution ops.

        Builds a (sparsely sampled) sweep over strides, paddings and
        dilations that produce valid - or deliberately invalid for ERROR_IF -
        output shapes, including the compliance information needed by the
        data generators.
        """
        # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        # Shape: Batches, (Depth), Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: (OFM channels), (KD), KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Op type checks
        conv3d = opName.startswith("conv3d")
        depthwise = opName.startswith("depthwise")

        # Check the rank
        rank = 5 if conv3d else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits channels
        k_rank = rank - 2
        # Depthwise filters lead with the kernel dims, others with OFM channels
        k_pos = 0 if depthwise else 1
        k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
        # compliance size - KS
        k_size = gtu.product(k_shape)
        if not depthwise:
            k_size *= ifm_shape[-1]

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            if error_name == ErrorIf.PadSmallerZero:
                p_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                # Stride must be greater than 1 to force non-integer error
                startStride = (
                    1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
                )
                s_vals = [
                    x for x in range(startStride, testGen.args.max_conv_stride + 1)
                ]
            strides = {x for x in itertools.product(*([s_vals] * k_rank))}
            if error_name == ErrorIf.DilationSmallerOne:
                d_vals = [testGen.rng.choice(range(-5, 1))]
            else:
                d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
            dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[0, bigPadding]] * (k_rank * 2))
                            )
                        }
                    )
                bigStride = 8
                strides.update(
                    {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
                )
                bigDilation = 7
                dilations.update(
                    {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
                )
            # No output size limit outside the 8k level tests
            max_dim_size = None

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(paddings) * len(strides) * len(dilations), sparsity_factor
            )
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            dilation_shape = [1] * k_rank
            pad_shape = [0] * k_rank * 2
            if conv3d:
                # Small stride apart from for big kernel (see below) to keep
                # tensor size/calculation small
                stride_shape = [1] * k_rank
                for idx in range(k_rank):
                    pad_offset = idx * 2
                    if k_shape[idx] == bigKernel:
                        # Padding shape needs to account for tensor shape
                        pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                        pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                        # Big stride to reduce output size
                        stride_shape[idx] = bigKernel
                    else:
                        # Account for kernel size
                        pad_shape[pad_offset] = k_shape[idx] - 1
            else:
                # Always have a large stride with extra padding and dilation to keep
                # tensor calculation reasonable
                stride_shape = [bigKernel] * k_rank
                for idx in range(k_rank):
                    # Dilation shape must account for kernel size
                    dilation_shape[idx] = bigKernel // k_shape[idx]
                    # Padding shape needs to accommodate tensor/kernel & dilation
                    pad_offset = idx * 2
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

            strides = {tuple(stride_shape)}
            dilations = {tuple(dilation_shape)}
            paddings = {tuple(pad_shape)}
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        # Sweep every (stride, padding, dilation) combination, keeping only
        # every `sparsity`-th candidate that yields a usable output shape
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a positive
                        # sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
                        and (
                            k_rank < 3
                            or (
                                (ifm_shape[3] - 1 + p[4] + p[5])
                                > d[2] * (k_shape[2] - 1)
                            )
                        )
                    ):
                        remainders = []
                        outputs = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            partial = (
                                ifm_shape[index + 1]
                                - 1
                                + p[pad_offset]
                                + p[pad_offset + 1]
                                - (k_shape[index] - 1) * d[index]
                            )
                            remainders.append(partial % s[index])
                            outputs.append((partial // s[index]) + 1)

                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            if (
                                max_dim_size is not None
                                and max(outputs) >= max_dim_size
                            ):
                                # Test will consume too much memory - skip it
                                continue

                            # Compliance - number of dot product calculations
                            if depthwise:
                                # N*OH*OW*C*M
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, *filter_shape[2:])
                                )
                            else:
                                # N*OH*OW*OC or N*OD*OH*OW*OC
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, filter_shape[0])
                                )
                            args_dict = {
                                "acc_type": accum_dtype,
                                "stride": s,
                                "pad": p,
                                "dilation": d,
                                "kernel": k_shape,
                                "ks": k_size,
                                "dot_products": dots,
                                "shape": ifm_shape,
                            }

                            # Support for larger values than 9 needs different delimiter
                            delim = "" if max(s + p + d) <= 9 else "x"
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        delim.join([str(x) for x in s]),
                                        delim.join([str(x) for x in p]),
                                        delim.join([str(x) for x in d]),
                                    ),
                                    args_dict,
                                )
                            )
                    n += 1

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtypes[0],
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2118
2119 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01002120 def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
2121
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002122 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002123 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01002124
2125 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson1271c442023-09-05 11:39:26 +01002126 accum_dtype = gtu.get_wrong_output_type(opName, testGen.rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01002127 elif error_name == ErrorIf.WrongInputType:
2128 # Pick some potentially correct output dtype if input type is incorrect
2129 accum_dtype = DType.INT32
2130 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +01002131 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01002132
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002133 # Set up compliance info
2134 args_dict = {
2135 "acc_type": accum_dtype,
2136 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
2137 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
2138 "shape": shapeList[0],
2139 }
2140
2141 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
2142
2143 arg_list = TosaArgGen._add_data_generators(
2144 testGen,
2145 opName,
2146 input_dtype,
2147 arg_list,
2148 error_name,
2149 )
2150 # Return list of tuples: (arg_str, args_dict)
2151 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002152
2153 @staticmethod
2154 def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
2155 # Get valid accumulate type(s)
2156 if dtype == DType.INT8:
2157 accum_dtypes = [DType.INT32]
2158 elif dtype == DType.INT16:
2159 accum_dtypes = [DType.INT48]
2160 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002161 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002162 elif dtype == DType.BF16:
2163 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002164 elif dtype == DType.FP32:
2165 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002166 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2167 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002168 elif error_name is None:
2169 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
2170
2171 if error_name == ErrorIf.WrongOutputType:
2172 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson1271c442023-09-05 11:39:26 +01002173 accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01002174 elif error_name == ErrorIf.WrongInputType:
2175 # Pick some potentially correct output dtype if input type is incorrect
2176 accum_dtypes = [DType.INT32]
2177
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002178 # Set up compliance info
2179 args_dict = {
2180 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
2181 # Set dot_products = N*H*W
2182 "dot_products": gtu.product(
2183 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
2184 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002185 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002186 }
2187
2188 # Create arg tuple of string and dict
2189 arg_list = []
2190 for a in accum_dtypes:
2191 d = args_dict.copy()
2192 d["acc_type"] = a
2193 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01002194
2195 arg_list = TosaArgGen._add_data_generators(
2196 testGen,
2197 opName,
2198 dtype,
2199 arg_list,
2200 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002201 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002202 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01002203 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002204
2205 @staticmethod
2206 def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002207 arg_list = []
2208
Jeremy Johnson0c716862023-04-13 17:18:19 +01002209 if testGen.args.level8k and error_name is not None:
2210 # Don't produce negative large tests
2211 return arg_list
2212
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002213 ifm_shape = shapeList[0]
2214 filter_shape = shapeList[1]
2215
Jeremy Johnson1271c442023-09-05 11:39:26 +01002216 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01002217
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002218 # Must be rank 4
2219 if error_name != ErrorIf.WrongRank:
2220 assert len(ifm_shape) == 4
2221 assert len(filter_shape) == 4
2222
Jeremy Johnson0c716862023-04-13 17:18:19 +01002223 k_shape = tuple(filter_shape[1:3])
Jeremy Johnson95a67102024-01-10 14:16:39 +00002224 # compliance size - KS
2225 k_size = gtu.product((*k_shape, ifm_shape[3]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002226
Jeremy Johnson0c716862023-04-13 17:18:19 +01002227 if not testGen.args.level8k:
2228 # Generate comprehensive argument lists
2229 # - except for named errors, which use specific invalid value(s)
2230 smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
2231 if error_name == ErrorIf.PadLargerEqualKernel:
2232 max_filter_size = -max(k_shape[0], k_shape[1])
2233 p_vals = [
2234 testGen.rng.choice(range(max_filter_size - 10, max_filter_size))
2235 ]
2236 else:
2237 p_vals = [
2238 x
2239 for x in range(
2240 smallest_padding_size, testGen.args.max_conv_padding + 1
2241 )
2242 ]
2243 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2244 if error_name == ErrorIf.StrideSmallerOne:
2245 # Can't use stride=0, as it is used to derive output shape, as a divisor
2246 s_vals = [testGen.rng.choice(range(-5, 0))]
2247 else:
2248 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
2249 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002250
Jeremy Johnson0c716862023-04-13 17:18:19 +01002251 if not error_name and testGen.args.oversize:
2252 # add some oversize argument values
2253 if max(ifm_shape) < 64:
2254 bigPadding = 9
2255 paddings.update(
2256 {
2257 x
2258 for x in itertools.product(
2259 *([[smallest_padding_size, bigPadding]] * 4)
2260 )
2261 }
2262 )
2263 bigStride = 8
2264 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
2265
2266 # There are too many parameter combinations, so generate them sparsely,
2267 # very sparse for negative tests
2268 sparsity_factor = 2 if error_name else 10
2269 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
2270 # If there are only a small number of tests, just select them all
2271 if sparsity < 13:
2272 sparsity = 1
2273 # To get a variety of parameter combinations sparsity should not be a
2274 # multiple of 2, 3 or 5
2275 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
2276 sparsity += 1
2277 else:
2278 # Only test 8k levels boundaries
2279 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2280 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2281 bigPadding = bigKernel
2282
2283 pad_shape = [0] * (len(k_shape) * 2)
2284 stride_shape = [1] * len(k_shape)
2285 # The point at which input dimension combined with the stride will
2286 # create large output sizes!
2287 LARGE_SIZE = 2
2288 for idx in range(len(k_shape)):
2289 pad_offset = idx * 2
2290 if k_shape[idx] == bigKernel:
2291 # Set large stride
2292 stride_shape[idx] = bigKernel
2293 # Use negative output padding to reduce shape size
2294 pad_shape[pad_offset] = -(bigPadding - 1)
2295 if ifm_shape[idx + 1] > LARGE_SIZE:
2296 pad_shape[pad_offset + 1] = -(bigPadding - 1)
2297 else:
2298 # The other dimension should be the bigKernel
2299 alt_idx = 1 - idx
2300 if (
2301 k_shape[alt_idx] == bigKernel
2302 and ifm_shape[alt_idx + 1] < LARGE_SIZE
2303 ):
2304 # As the input is small, the large stride won't
2305 # affect the output so we can add some padding
2306 pad_shape[pad_offset + 1] = bigPadding
2307
2308 strides = {tuple(stride_shape)}
2309 paddings = {tuple(pad_shape)}
2310
2311 # Currently allow all combinations that are reasonable size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002312 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002313
2314 n = 0
2315 for s in sorted(list(strides)):
2316 for p in sorted(list(paddings)):
TatWai Chong24594f52022-06-08 00:48:04 -07002317 if n % sparsity == 0:
2318 # Determine the output shape
Jeremy Johnson0c716862023-04-13 17:18:19 +01002319 oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
2320 ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
TatWai Chong24594f52022-06-08 00:48:04 -07002321 os = [ifm_shape[0], oh, ow, filter_shape[0]]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002322
Jeremy Johnson95a67102024-01-10 14:16:39 +00002323 # N*OH*OW*OC
2324 dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
2325 args_dict = {
2326 "acc_type": accum_dtype,
2327 "stride": s,
2328 "pad": p,
2329 "kernel": k_shape,
2330 "ks": k_size,
2331 "dot_products": dots,
2332 "shape": ifm_shape,
2333 "out_shape": os,
2334 }
2335
Jeremy Johnson0c716862023-04-13 17:18:19 +01002336 # Support for larger values than 9 needs different delimiter
2337 delim = "" if max(s + p) <= 9 else "x"
TatWai Chong24594f52022-06-08 00:48:04 -07002338 arg_list.append(
2339 (
James Ward8b390432022-08-12 20:48:56 +01002340 "acc{}_st{}_pad{}_os{}".format(
2341 testGen.typeStr(accum_dtype),
Jeremy Johnson0c716862023-04-13 17:18:19 +01002342 delim.join([str(x) for x in s]),
2343 delim.join([str(x) for x in p]),
TatWai Chong24594f52022-06-08 00:48:04 -07002344 "x".join([str(x) for x in os]),
2345 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00002346 args_dict,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002347 )
TatWai Chong24594f52022-06-08 00:48:04 -07002348 )
2349 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002350
Jeremy Johnson95a67102024-01-10 14:16:39 +00002351 arg_list = TosaArgGen._add_data_generators(
2352 testGen,
2353 opName,
2354 dtypes[0],
2355 arg_list,
2356 error_name,
2357 )
2358 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002359 return arg_list
2360
    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        """Build the test argument list for the PAD operator.

        Args:
            testGen: test generator providing RNG and command line args
            opName: operator name string
            shapeList: list of input shapes; only shapeList[0] is used
            dtype: DType of the input tensor
            error_name: optional ErrorIf name to create negative tests for

        Returns:
            List of (arg_str, args_dict) tuples; args_dict holds "pad"
            (a rank x 2 numpy array of before/after padding) plus the
            "pad_const_int" and "pad_const_fp" fill values.
        """
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            # Negative tests use only negative padding values
            pad_values = [x for x in range(-2, 0)]
        # (before, after) pairs for a single axis
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        # Cartesian product over all axes - one entry per complete padding spec
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        # Randomize the pad constant matching the dtype kind; the other is zero
        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype in (
            DType.FP16,
            DType.BF16,
            DType.FP32,
            DType.FP8E4M3,
            DType.FP8E5M2,
        ):
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            # Unsupported dtype - no tests to create
            return []

        list_shape_pad_values = list(shape_pad_values)
        # If we are producing tests for rank 6 or greater use sparsity
        if len(list_shape_pad_values) > 1024:
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(list_shape_pad_values), sparsity_factor
            )
        else:
            sparsity = 1

        # Build arg list
        arg_list = []
        for n, paddings in enumerate(list_shape_pad_values):
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        # Clamp this axis to zero padding to keep the output valid
                        paddings[i] = (0, 0)
                    if all([p > -1 for p in paddings[i]]):
                        # NOTE(review): the combination is dropped as soon as ANY
                        # axis has no negative padding left - confirm intent
                        args_valid = False
            if args_valid and n % sparsity == 0:
                # Test name encodes the before/after padding of every axis
                name = "pad"
                for r in range(rank):
                    before, after = paddings[r]
                    name = f"{name}{before}{after}"
                args_dict = {
                    "pad": np.array(paddings),
                    "pad_const_int": pad_const_int,
                    "pad_const_fp": pad_const_fp,
                }
                arg_list.append((name, args_dict))

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

        # Attach data generator meta data to each argument combination
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2441
2442 @staticmethod
2443 def agPooling(testGen, opName, shapeList, dtype, error_name=None):
2444 arg_list = []
2445
2446 shape = shapeList[0]
2447 if error_name != ErrorIf.WrongRank:
2448 assert len(shape) == 4
2449
Jeremy Johnson0c716862023-04-13 17:18:19 +01002450 test_level8k = testGen.args.level8k and error_name is None
2451
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002452 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002453 startKernel = 2
2454 startPad = 0
2455 if not test_level8k:
2456 # Generate comprehensive argument lists
2457 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2458 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2459 # Stride must be greater than 1 to force non-integer error
2460 s_vals = [
2461 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2462 ]
2463 strides = {x for x in itertools.product(*([s_vals] * 2))}
2464 k_vals = [
2465 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2466 ]
2467 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2468 max_dim_size = None
2469 else:
2470 # Only test 8k levels
2471 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2472 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2473 strides = {(1, bigStride), (bigStride, 4)}
2474 kernels = {(1, bigKernel), (bigKernel, 3)}
2475 paddings = set()
2476 for s in sorted(list(strides)):
2477 for k in sorted(list(kernels)):
2478 padding = []
2479 for idx in range(len(k)):
2480 total_padding = s[idx] - shape[idx + 1] + k[idx]
2481 while total_padding < 0:
2482 # Must meet: shape + padding > kernel
2483 total_padding += s[idx]
2484 if total_padding < k[idx]:
2485 padding.extend([0, total_padding])
2486 else:
2487 # Note this may produce padding >= k[idx] which is not
2488 # allowed - but will be ignored in the creation loop below
2489 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2490 paddings.add(tuple(padding))
2491 # Create a limit for the output dimensions size
2492 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002493
James Ward8b390432022-08-12 20:48:56 +01002494 if opName == "max_pool2d":
2495 accum_dtypes = [None] # max_pool has no accumulate dtype
2496 elif dtype == DType.INT8 or dtype == DType.INT16:
2497 accum_dtypes = [DType.INT32]
2498 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002499 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002500 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002501 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002502 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2503 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002504 elif error_name is None:
2505 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2506 else:
2507 # Set to something for the ErrorIf case which has
2508 # incorrect input data-type
2509 accum_dtypes = [DType.INT32]
2510
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00002511 if error_name == ErrorIf.WrongAccumulatorType:
2512 accum_dtypes = list(gtu.usableDTypes(excludes=accum_dtypes))
2513
Jeremy Johnson0c716862023-04-13 17:18:19 +01002514 if not test_level8k:
2515 if testGen.args.oversize:
2516 # add some oversize argument values
2517 bigStride = 7
2518 bigKernel = 9
2519 strides.update(
2520 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002521 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002522 kernels.update(
2523 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2524 )
2525 if max(shape) < 64:
2526 # padding must be less than the kernel size
2527 bigPadding = bigKernel - 1
2528 paddings.update(
2529 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2530 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002531
Jeremy Johnson0c716862023-04-13 17:18:19 +01002532 # There are too many parameter combinations, so generate them sparsely,
2533 # very sparse for negative tests
2534 sparsity_factor = 2 if error_name else 500
2535 sparsity = (
2536 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2537 )
2538 else:
2539 # We have already limited test output combinations for 8k tests
2540 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002541
James Ward8b390432022-08-12 20:48:56 +01002542 arg_str = (
2543 "acc{}_st{}_kern{}_pad{}"
2544 if accum_dtypes[0] is not None
2545 else "st{}_kern{}_pad{}"
2546 )
2547
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002548 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01002549 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002550 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01002551
2552 # Support for larger values than 9 needs different delimiter
2553 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01002554 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01002555 delim.join([str(x) for x in stride]),
2556 delim.join([str(x) for x in kern]),
2557 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01002558 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002559 args_dict = {
2560 "stride": stride,
2561 "pad": pad,
2562 "kernel": kern,
2563 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002564 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002565 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
2566 }
James Ward8b390432022-08-12 20:48:56 +01002567
2568 if accum is not None:
2569 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002570 args_dict["acc_type"] = accum
2571 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002572
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002573 n = 0
James Ward8b390432022-08-12 20:48:56 +01002574 for a in accum_dtypes:
2575 for s in sorted(list(strides)):
2576 for p in sorted(list(paddings)):
2577 for k in sorted(list(kernels)):
2578 if error_name in [
2579 ErrorIf.StrideSmallerOne,
2580 ErrorIf.KernelSmallerOne,
2581 ErrorIf.PadSmallerZero,
2582 ErrorIf.PadLargerEqualKernel,
2583 ]:
2584 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
2585 testGen, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002586 )
James Ward8b390432022-08-12 20:48:56 +01002587 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002588 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002589 get_arg_list_element(a, sNew, pNew, kNew, shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002590 )
James Ward8b390432022-08-12 20:48:56 +01002591 elif (
2592 n % sparsity == 0
2593 # padding must not exceed the kernel size
2594 and p[0] < k[0]
2595 and p[1] < k[0]
2596 and p[2] < k[1]
2597 and p[3] < k[1]
2598 # the padded shape must exceed the kernel size
2599 and (shape[1] + p[0] + p[1]) > k[0]
2600 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002601 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002602 partial_h = shape[1] + p[0] + p[1] - k[0]
2603 partial_w = shape[2] + p[2] + p[3] - k[1]
2604 remainder_h = partial_h % s[0]
2605 remainder_w = partial_w % s[1]
2606 output_h = partial_h // s[0] + 1
2607 output_w = partial_w // s[1] + 1
2608 # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
James Ward8b390432022-08-12 20:48:56 +01002609 if (
2610 # the parameters must produce integer exact output
2611 error_name != ErrorIf.PoolingOutputShapeNonInteger
2612 and remainder_h == 0
2613 and remainder_w == 0
2614 ) or (
2615 error_name == ErrorIf.PoolingOutputShapeNonInteger
2616 and (remainder_h != 0 or remainder_w != 0)
2617 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002618 if (
2619 max_dim_size is not None
2620 and max(output_h, output_w) > max_dim_size
2621 ):
2622 # Test will consume too much memory - skip it
2623 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002624 # Dot products = N*OH*OW*C
2625 dp = gtu.product(
2626 (shape[0], output_h, output_w, shape[3])
2627 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002628 arg_list.append(
2629 get_arg_list_element(a, s, p, k, dp, shape)
2630 )
James Ward8b390432022-08-12 20:48:56 +01002631 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002632
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002633 # Now add data generator types
2634 arg_list = TosaArgGen._add_data_generators(
2635 testGen,
2636 opName,
2637 dtype,
2638 arg_list,
2639 error_name,
2640 )
2641
2642 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002643 return arg_list
2644
2645 @staticmethod
2646 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
2647 arg_list = []
2648
2649 # Enumerate the output types here
2650 if error_name == ErrorIf.WrongOutputType:
2651 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
2652 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00002653 dtypeList = [
2654 DType.BOOL,
2655 DType.INT16,
2656 DType.INT32,
2657 DType.FP16,
2658 DType.BF16,
2659 DType.FP32,
2660 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002661 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00002662 dtypeList = [
2663 DType.BOOL,
2664 DType.INT8,
2665 DType.INT32,
2666 DType.FP16,
2667 DType.BF16,
2668 DType.FP32,
2669 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002670 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00002671 dtypeList = [
2672 DType.BOOL,
2673 DType.INT8,
2674 DType.INT16,
2675 DType.FP16,
2676 DType.BF16,
2677 DType.FP32,
2678 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002679 elif inDtype == DType.BOOL:
2680 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01002681 elif inDtype == DType.FP16:
Won Jeon2c34b462024-02-06 18:37:00 +00002682 dtypeList = [
2683 DType.INT8,
2684 DType.INT16,
2685 DType.INT32,
2686 DType.FP32,
2687 DType.FP8E4M3,
2688 DType.FP8E5M2,
2689 ]
James Ward24dbc422022-10-19 12:20:31 +01002690 elif inDtype == DType.BF16:
Won Jeon2c34b462024-02-06 18:37:00 +00002691 dtypeList = [
2692 DType.INT8,
2693 DType.INT16,
2694 DType.INT32,
2695 DType.FP32,
2696 DType.FP8E4M3,
2697 DType.FP8E5M2,
2698 ]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002699 elif inDtype == DType.FP32:
Won Jeon2c34b462024-02-06 18:37:00 +00002700 dtypeList = [
2701 DType.INT8,
2702 DType.INT16,
2703 DType.INT32,
2704 DType.FP16,
2705 DType.BF16,
2706 DType.FP8E4M3,
2707 DType.FP8E5M2,
2708 ]
2709 elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]:
2710 dtypeList = [DType.FP16, DType.BF16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002711 elif error_name == ErrorIf.WrongInputType:
2712 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002713 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002714 else:
2715 raise Exception("Unexpected input dtype: {}".format(inDtype))
2716
2717 for dtype in dtypeList:
Jeremy Johnson708da822023-11-15 16:25:45 +00002718 arg_list.append(
2719 ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
2720 )
2721
2722 # Now add data generator types
2723 arg_list = TosaArgGen._add_data_generators(
2724 testGen,
2725 opName,
2726 dtype,
2727 arg_list,
2728 error_name,
2729 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002730
2731 return arg_list
2732
    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        """Build the test argument list for the RESCALE operator.

        Enumerates combinations of output dtype, scale32, double_round and
        per_channel, skipping combinations that are invalid for the given
        input dtype (or keeping deliberately invalid ones for ErrorIf tests).

        Returns:
            List of (arg_str, args_dict) tuples.
        """
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            # Skip output dtypes that cannot carry a non-zero output zero point
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                # Only keep pairs that eiRescaleWrongOutputType deems suitable
                continue

            # Enumerate the attribute combinations for this in/out dtype pair
            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition.  Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition.  ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                {
                                    "output_dtype": outDtype,
                                    "scale": scale32,
                                    "double_round": double_round,
                                    "per_channel": per_channel,
                                },
                            )
                        )

        # Attach data generator meta data to each argument combination
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            inDtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2844
2845 @staticmethod
2846 def agMul(testGen, opName, shapeList, dtype, error_name=None):
2847 arg_list = []
2848
2849 if dtype is DType.INT32:
2850 for p in range(testGen.args.num_rand_permutations):
2851
2852 shift = testGen.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002853 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002854 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002855 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002856
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002857 arg_list = TosaArgGen._add_data_generators(
2858 testGen,
2859 opName,
2860 dtype,
2861 arg_list,
2862 error_name,
2863 )
2864 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002865 return arg_list
2866
2867 @staticmethod
2868 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
2869 arg_list = []
2870
Jeremy Johnson587cc842024-02-08 11:45:44 +00002871 for round in (True, False):
2872 args_dict = {
2873 "round": round,
2874 }
2875 arg_list.append((f"round{round}", args_dict))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002876
Jeremy Johnson587cc842024-02-08 11:45:44 +00002877 arg_list = TosaArgGen._add_data_generators(
2878 testGen,
2879 opName,
2880 dtype,
2881 arg_list,
2882 error_name,
2883 )
2884 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002885 return arg_list
2886
Luke Hutton57287132023-02-06 14:54:18 +00002887 @staticmethod
2888 def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
2889 arg_list = []
2890
Jeremy Johnsonc8330812024-01-18 16:57:28 +00002891 shape = shapeList[0]
2892 dot_products = gtu.product(shape)
2893 ks = 2 * shape[1] * shape[2] # 2*H*W
2894 for inverse in (True, False):
2895 args_dict = {
2896 "dot_products": dot_products,
2897 "shape": shape,
2898 "ks": ks,
2899 "acc_type": dtype,
2900 "inverse": inverse,
2901 }
2902 arg_list.append((f"inverse{inverse}", args_dict))
Luke Hutton57287132023-02-06 14:54:18 +00002903
Jeremy Johnsonc8330812024-01-18 16:57:28 +00002904 arg_list = TosaArgGen._add_data_generators(
2905 testGen,
2906 opName,
2907 dtype,
2908 arg_list,
2909 error_name,
2910 )
2911 # Return list of tuples: (arg_str, args_dict)
Luke Hutton57287132023-02-06 14:54:18 +00002912 return arg_list
2913
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00002914 @staticmethod
2915 def agRFFT2d(testGen, opName, shapeList, dtype, error_name=None):
2916 arg_list = []
2917
2918 shape = shapeList[0]
2919 dot_products = gtu.product(shape)
2920 ks = shape[1] * shape[2] # H*W
2921 args_dict = {
2922 "dot_products": dot_products,
2923 "shape": shape,
2924 "ks": ks,
2925 "acc_type": dtype,
2926 }
2927 arg_list.append(("", args_dict))
2928
2929 arg_list = TosaArgGen._add_data_generators(
2930 testGen,
2931 opName,
2932 dtype,
2933 arg_list,
2934 error_name,
2935 )
2936 # Return list of tuples: (arg_str, args_dict)
2937 return arg_list
2938
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002939 # Helper function for reshape. Gets some factors of a larger number.
2940 @staticmethod
2941 def getFactors(val, start=1):
2942 factors = []
2943
2944 for i in range(start, int(np.sqrt(val)) + 1):
2945 if (val % i) == 0:
2946 factors.append(i)
2947
2948 return factors
2949
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        """Build the test argument list for the RESHAPE operator.

        Generates random output shapes with the same total element count as
        the input, avoiding duplicate shapes.

        Returns:
            List of (arg_str, args_dict) tuples; args_dict holds "new_shape".
        """
        arg_list = []

        origShape = shapeList[0]
        totalElements = gtu.product(origShape)
        # Candidate factors for building new dimension sizes
        factors = TosaArgGen.getFactors(totalElements)

        # Find new shapes up to the number of permutations asked for
        # This code is NOT fast. Fortunately, the numbers are fairly small.
        for p in range(testGen.args.num_rand_permutations):
            # Rank from 1 to TOSA_TENSOR_MAX_RANK
            newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
            if len(factors) < newRank:
                # Not enough factors to fill a shape of this rank
                continue

            # escape_counter limits the generation of new shapes to a reasonable time
            for escape_counter in range(100):

                # Generate the new shape of the chosen new rank
                newShape = []
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # Last dimension takes whatever element count is left
                newShape.append(remainingElements)

                # Check for duplicates
                duplicate = False
                for name, args_dict in arg_list:
                    if args_dict["new_shape"] == newShape:
                        duplicate = True
                        break

                if not duplicate:
                    outShape = "x".join([str(x) for x in newShape])
                    arg_list.append(
                        (
                            "perm{}_rank{}_out{}".format(p, newRank, outShape),
                            {"new_shape": newShape},
                        )
                    )
                    # Found an output shape for this permutation
                    break

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        return arg_list
3010
3011 @staticmethod
3012 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
3013 arg_list = []
3014
3015 ifm_shape = shapeList[0]
3016
3017 if error_name == ErrorIf.IndexOutsideBounds:
3018 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
3019 incorrect_small_index = range(-len(ifm_shape), 0)
3020 permutations = [p for p in itertools.permutations(incorrect_large_index)]
3021 permutations.extend(
3022 [p for p in itertools.permutations(incorrect_small_index)]
3023 )
3024 elif error_name == ErrorIf.IndexUsedTwice:
3025 # Create list with a duplicated index
3026 perm_range = list(range(len(ifm_shape)))
3027 index_choice = testGen.rng.choice(range(len(perm_range)))
3028 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
3029 permutations = [p for p in itertools.permutations(perm_range)]
3030
3031 else:
3032 # Get all permutations
3033 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
3034
3035 # Limit to possible permutations from shape dimension or argument setting
3036 limit = min(len(permutations), testGen.args.num_rand_permutations)
3037
3038 # Get random permutation generator that uses all permutations
3039 random_permutations = testGen.rng.permutation(permutations)
3040
3041 # Create list of required amount of permutations
3042 arg_list = [
evacha0198477222024-01-26 12:25:32 +00003043 ("perm{}".format(p), {"perms": random_permutations[p].tolist()})
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003044 for p in range(limit)
3045 ]
evacha0198477222024-01-26 12:25:32 +00003046 # Now add data generator types
3047 arg_list = TosaArgGen._add_data_generators(
3048 testGen,
3049 opName,
3050 dtype,
3051 arg_list,
3052 error_name,
3053 )
3054 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003055 return arg_list
3056
3057 @staticmethod
3058 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
3059 arg_list = []
3060
3061 ifm_shape = shapeList[0]
3062 rank = len(ifm_shape)
3063
3064 for p in range(testGen.args.num_rand_permutations):
3065 start = []
3066 size = []
3067
3068 valid = True
3069
3070 for i in range(rank):
3071 if ifm_shape[i] > 1:
3072 start.append(testGen.randInt(0, ifm_shape[i]))
3073 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
3074
3075 # Invalid slice size?
3076 if size[i] == 0:
3077 valid = False
3078 else:
3079 start.append(0)
3080 size.append(1)
3081
3082 if valid:
3083 # If ERROR_IF test required then incorrect start, size will be returned
3084 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
3085 testGen, error_name, ifm_shape, start, size
3086 )
evacha017f7d4252024-01-24 12:08:09 +00003087 arg_list.append(("perm{}".format(p), {"start": start, "size": size}))
3088 # Now add data generator types
3089 arg_list = TosaArgGen._add_data_generators(
3090 testGen,
3091 opName,
3092 dtype,
3093 arg_list,
3094 error_name,
3095 )
3096 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003097 return arg_list
3098
3099 @staticmethod
3100 def agTile(testGen, opName, shapeList, dtype, error_name=None):
3101 arg_list = []
3102
3103 ifm_shape = shapeList[0]
3104 rank = len(ifm_shape)
3105
3106 for p in range(testGen.args.num_rand_permutations):
3107
3108 # Pick a few random, but small multiple values
3109 # because otherwise this has a tendency to generate
3110 # enormous tensors
3111 multiples = []
3112 for i in range(rank):
3113 if ifm_shape[i] > 1000:
3114 # Multiple of 1 if ifm_shape dimension is large to reduce
3115 # tensor size
3116 multiples.append(1)
3117 elif max(ifm_shape) > 1000:
3118 multiples.append(2)
3119 else:
3120 multiples.append(testGen.randInt(1, 4))
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003121 arg_list.append(("perm{}".format(p), {"multiples": multiples}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003122
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003123 # Now add data generator types
3124 arg_list = TosaArgGen._add_data_generators(
3125 testGen,
3126 opName,
3127 dtype,
3128 arg_list,
3129 error_name,
3130 )
3131 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003132 return arg_list
3133
    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Generate argument sets for the RESIZE operator.

        Randomly picks (scale, offset, border) parameters using one of several
        strategies - plain random, upscale/downscale, aspect-ratio change, or
        level-8k limit testing - then validates/adjusts them so the computed
        output dimensions are legal integers, unless a specific ERROR_IF case
        is being generated.

        Returns a list of tuples: (arg_str, args_dict) where args_dict holds
        "mode", "scale", "offset", "border" and "output_dtype".
        """
        arg_list = []
        ifm_shape = shapeList[0]

        def get_aspect_ratio_resize_params():
            """Pick params for a common aspect-ratio change, with either
            letterbox (y border) or pillarbox (x border) style borders."""
            common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
            aspect_ratio = testGen.rng.choice(common_aspect_ratios)
            invert = testGen.rng.choice((False, True))
            letterbox = testGen.rng.choice((False, True))

            scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
            scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
            scale_y_d = scale_x_d = 1
            offset_x = offset_y = 0

            if letterbox:
                max_border = scale_y_n
                border_y = testGen.randInt(low=0, high=max_border)
                border_x = 0
            else:
                # Pillarboxing
                border_y = 0
                max_border = scale_x_n
                border_x = testGen.randInt(low=0, high=max_border)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)

            return scale, offset, border

        def get_upscale_downscale_params():
            """Pick params for a power-of-two upscale or an integer downscale
            whose denominator divides (input dim - 1) exactly.

            Loops until a valid parameter set is found (the downscale branch
            can fail when no denominator fits either spatial dimension).
            """
            valid_params = False
            while not valid_params:
                upscale = testGen.rng.choice((False, True))

                # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
                origin_sampling = testGen.rng.choice((False, True))

                if upscale:
                    shift = testGen.randInt(low=1, high=4)
                    scale_x_d = scale_y_d = 1
                    scale_x_n = scale_y_n = (
                        1 << shift if origin_sampling else 2 << shift
                    )
                    border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                    offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
                else:
                    scale_x_n = 1
                    scale_y_n = 1

                    # Return list of valid scale_*_d values (max value 4) given input dim shape
                    def get_valid_denom(ifm_dim):
                        # ifm_dim % x == 1 means (ifm_dim - 1) is divisible by x
                        return [x for x in range(1, 5) if ifm_dim % x == 1]

                    # Generate list of valid downscale values and choose one randomly
                    valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                    valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                    if not valid_scale_y_ds and not valid_scale_x_ds:
                        # Bad parameters, skip
                        continue

                    if not valid_scale_y_ds:
                        scale_y_d = 1
                    else:
                        scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                    if not valid_scale_x_ds:
                        scale_x_d = 1
                    else:
                        scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                    border_x = border_y = 0
                    offset_y = testGen.randInt(0, 16 * scale_y_n)
                    offset_x = testGen.randInt(0, 16 * scale_x_n)
                valid_params = True

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_rand_params():
            """Pick fully random params, clamped to the level max scale."""

            def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
                # Grow the denominator just enough to bring n/d under max_scale
                scale = scale_n / scale_d
                if scale > max_scale:
                    factor = scale / max_scale
                    new_scale_d = math.ceil(scale_d * factor)
                    assert scale_n / new_scale_d <= max_scale
                    scale_d = new_scale_d
                return scale_d

            # Scale
            scale_y_n = testGen.randInt(low=1, high=(1 << 11))
            scale_x_n = testGen.randInt(low=1, high=(1 << 11))

            scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

            scale_y_d = fix_scale_to_max_scale(
                scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_x_d = fix_scale_to_max_scale(
                scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )

            # Offsets and border within the scale
            offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
            border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_level_8k_params():
            """Pick params at the level-8k scaling limits (one axis at the
            max upscale, the other a mild downscale)."""
            # Create 64x scale - 64/1 to 2048/32
            scale_d = testGen.randInt(
                low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
            # Create half to fifth scaling
            scale_d_alt = testGen.randInt(low=2, high=6)
            scale_n_alt = 1
            switch = testGen.rng.choice((False, True))
            if switch:
                scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
            else:
                scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)

            # Offsets/borders at their extreme legal values
            offset_y = testGen.rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
            offset_x = testGen.rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
            offset = (offset_y, offset_x)
            border_y = testGen.rng.choice((-16 * scale[0], 0, scale[0] - 1))
            border_x = testGen.rng.choice((-16 * scale[2], 0, scale[2] - 1))
            border = (border_y, border_x)
            return scale, offset, border

        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.BF16:
                outputDTypeList = [DType.BF16]
            elif dtype == DType.FP32:
                outputDTypeList = [DType.FP32]
            elif dtype == DType.FP8E4M3:
                outputDTypeList = [DType.FP8E4M3]
            elif dtype == DType.FP8E5M2:
                outputDTypeList = [DType.FP8E5M2]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                # Build up to num_rand_permutations unique argument sets;
                # perm only advances once a set is accepted (or, for skips,
                # when at least one test already exists - bounding the work)
                perm = 0
                while perm < testGen.args.num_rand_permutations:
                    # Random choice of type of params we are testing
                    if not testGen.args.level8k:
                        _rnd_param_fn = testGen.rng.choice(
                            (
                                get_rand_params,
                                get_upscale_downscale_params,
                                get_aspect_ratio_resize_params,
                            )
                        )
                        scale, offset, border = _rnd_param_fn()
                    else:
                        scale, offset, border = get_level_8k_params()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        # Look for non-integer test
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            if perm > 0:
                                perm += 1
                            continue
                    else:
                        # Alter the scaling factors to make the output integer
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1
                        # Make sure we are still within max scaling
                        if (
                            scale_y_n / scale_y_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
                            scale_x_n / scale_x_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
                            # Skip the test as it is using too large a scaling factor
                            if perm > 0:
                                perm += 1
                            continue

                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1

                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        if not testGen.args.level8k or perm > 0:
                            perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= gtu.MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= gtu.MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]

                    # Common for all data types
                    if error_name is not None:
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        {
                            "mode": mode,
                            "scale": scale,
                            "offset": offset,
                            "border": border,
                            "output_dtype": outputDTypeNew,
                        },
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
3466
3467 @staticmethod
3468 def agTable(testGen, opName, shapeList, dtype, error_name=None):
3469 arg_list = []
3470
3471 if dtype == DType.INT8:
3472 table = np.int32(
3473 testGen.rng.integers(low=-128, high=128, size=[256])
3474 ).tolist()
3475 else: # INT16
3476 table = np.int32(
3477 testGen.rng.integers(low=-32768, high=32768, size=[513])
3478 ).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003479 # Make sure all slopes are within REQUIRE min/max 16-bit int
3480 for idx in range(len(table) - 1):
3481 slope = table[idx + 1] - table[idx]
3482 # Alter the next table entry to force the slope to be ok
3483 if slope > 32767:
3484 table[idx + 1] -= slope - 32767
3485 if slope < -32768:
3486 table[idx + 1] -= slope + 32768
3487 slope = table[idx + 1] - table[idx]
3488 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003489 arg_list.append(
3490 (
3491 "",
Jeremy Johnson587cc842024-02-08 11:45:44 +00003492 {"table": table},
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003493 )
3494 )
Jeremy Johnson587cc842024-02-08 11:45:44 +00003495 # Now add data generator types
3496 arg_list = TosaArgGen._add_data_generators(
3497 testGen,
3498 opName,
3499 dtype,
3500 arg_list,
3501 error_name,
3502 )
3503 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003504 return arg_list
3505
3506 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
3507 # CondIf generates the condition values here.
3508 # Convert to tensors in the build function, along with the
3509 # then and else blocks
3510 arg_list = []
3511
3512 for c in [False, True]:
Jeremy Johnson587cc842024-02-08 11:45:44 +00003513 arg_list.append(("cond{}".format(int(c)), {"condition": c}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003514
Jeremy Johnson587cc842024-02-08 11:45:44 +00003515 # Now add data generator types
3516 arg_list = TosaArgGen._add_data_generators(
3517 testGen,
3518 opName,
3519 dtype,
3520 arg_list,
3521 error_name,
3522 )
3523 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003524 return arg_list
3525
3526 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
3527 # While loop: 0 iterations, 1, more than 1
3528 arg_list = []
3529
Jeremy Johnson587cc842024-02-08 11:45:44 +00003530 for iterations in [0, 1, 4]:
3531 arg_list.append(("iter{}".format(iterations), {"iterations": iterations}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003532
Jeremy Johnson587cc842024-02-08 11:45:44 +00003533 # Now add data generator types
3534 arg_list = TosaArgGen._add_data_generators(
3535 testGen,
3536 opName,
3537 dtype,
3538 arg_list,
3539 error_name,
3540 )
3541 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003542 return arg_list