blob: 592c49195da0fc539bff7235d305855847574612 [file] [log] [blame]
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
4import math
James Ward8b390432022-08-12 20:48:56 +01005import warnings
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01006
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
16# DTypeNames, DType, Op and ResizeMode are convenience variables to the
17# flatc-generated types that should be enums, but aren't
18
19
class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(testGen, dtype, error_name=None):
        """Return a zero point suitable for dtype.

        INT8/UINT8 get the user supplied zero point (clamped to the type's
        range) or a random in-range value.  Zero-point ERROR_IF tests get a
        guaranteed non-zero value.  All other cases must use zero.
        """
        if dtype == DType.INT8:
            if testGen.args.zeropoint is not None:
                return min(127, max(-128, testGen.args.zeropoint))
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            if testGen.args.zeropoint is not None:
                return min(255, max(0, testGen.args.zeropoint))
            return testGen.randInt(0, 256)
        elif error_name in (
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ):
            # Must be non-zero to trigger the ERROR_IF condition
            zero_point = testGen.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Return [input_zp, output_zp] for a unary operator, making the
        relevant zero point non-zero for zero-point ERROR_IF tests."""
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Return [input_zp, weight_zp] for a convolution style operator.

        dtype_or_dtypeList may be a single dtype or a list of
        [input, weights, accumulator] dtypes.
        """
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Return [a_zp, b_zp] for MATMUL; both are made non-zero for the
        InputZeroPointNotZero ERROR_IF test."""
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into a (multiplier, shift) pair.

        Derived from computeMultiplierAndShiftTosaScale32.  scale32 selects
        31 fractional bits in the multiplier encoding, otherwise 15 are
        used.  The shift is adjusted to stay within the range [2, 62].
        """
        scaleBits = 31 if scale32 else 15

        m, shift = math.frexp(scaleFp)

        # frexp returns a signed mantissa; only its magnitude is encoded
        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            # Rounding overflowed the mantissa: renormalize
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift
152
153
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        # Stateless helper class - all generators are static methods.
        pass
163
    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Create identical shapes for all operands of the operator.

        For RankMismatch ERROR_IF tests the shape used for subsequent
        operands is regenerated with a different rank.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list
185
186 @staticmethod
187 def tgNHWC(testGen, opName, rank, error_name=None):
188 pl, const = opName["operands"]
189
190 if error_name != ErrorIf.WrongRank:
191 assert rank == 4
192
193 shape = testGen.makeShape(rank)
James Ward30124a82023-02-02 14:56:33 +0000194 shape = testGen.constrictBatchSize(shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100195
196 # Constrict the overall size of the shape when creating ERROR_IF tests
197 if error_name and error_name != ErrorIf.MaxDimExceeded:
198 shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
199
200 shape_list = []
201 for i in range(pl + const):
202 shape_list.append(shape.copy())
203
204 return shape_list
205
206 @staticmethod
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000207 def tgGather(testGen, opName, rank, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100208 pl, const = opName["operands"]
209
210 assert pl == 2
211 assert const == 0
212 if error_name != ErrorIf.WrongRank:
213 assert rank == 3
214
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000215 values_shape = testGen.makeShape(rank)
216 values_shape = testGen.constrictBatchSize(values_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100217
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000218 N = values_shape[0]
219 W = testGen.makeDimension()
220 indices_shape = [N, W]
221
222 shape_list = [values_shape, indices_shape]
223 return shape_list
224
    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Create [values_in, indices, input] shapes for SCATTER.

        values_in is (N, K, C), indices is (N, W) and input is (N, W, C),
        where W is limited to no more than K.
        """
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = testGen.randInt(W_min, W_max) if W_min < W_max else W_min

        input_shape = [N, W, C]

        shape_list = []
        shape_list.append(values_in_shape)
        shape_list.append([N, W])  # indices
        shape_list.append(input_shape)

        return shape_list
255
    @staticmethod
    def _get_broadcast_shapes(testGen, num_shapes, rank, error_name=None):
        """Create num_shapes shapes of the given rank where one randomly
        chosen shape has a dimension set to 1 so it broadcasts against the
        others (or is made deliberately invalid for ERROR_IF tests).
        """
        shape = testGen.makeShape(rank)
        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, num_shapes)
        fuzz_idx = testGen.randInt(0, rank)

        for i in range(num_shapes):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    # Make the dimension incompatible (neither 1 nor equal)
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list
293
294 @staticmethod
Jeremy Johnson0a042992024-02-28 13:20:05 +0000295 def tgBroadcastFuzz(testGen, op, rank, error_name=None):
296 pl, const = op["operands"]
297 num_shapes = pl + const
298 return TosaTensorGen._get_broadcast_shapes(
299 testGen, num_shapes, rank, error_name
300 )
301
302 @staticmethod
303 def tgMul(testGen, op, rank, error_name=None):
304 # Get broadcast shapes for the first 2 inputs as the 3rd is shift
305 shape_list = TosaTensorGen._get_broadcast_shapes(testGen, 2, rank, error_name)
306 # Add a single dimension tensor for shift
307 shape_list.append([1])
308 return shape_list
309
310 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100311 def tgConv2D(testGen, op, rank, error_name=None):
312 pl, const = op["operands"]
313
314 if error_name != ErrorIf.WrongRank:
315 assert rank == 4
316
317 # IFM dimensions are NHWC
318 ifm_shape = testGen.makeShape(rank)
James Ward30124a82023-02-02 14:56:33 +0000319 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100320
321 # Constrict the overall size of the shape when creating ERROR_IF tests
322 if error_name:
323 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
324 ifm_shape, max_dim=24, max_items=10000
325 )
326
327 # Get the filter height/width from the operator parameters
328 filter_hw = op["filter"]
329
330 # Generate a random OFM depth
James Ward30124a82023-02-02 14:56:33 +0000331 ofm_depth = testGen.makeDimension()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100332
333 # The filter dimensions are OHWI
334 filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
335
336 # The bias is OC
337 bias_shape = np.asarray([ofm_depth])
338
339 return [ifm_shape, filter_shape, bias_shape]
340
341 @staticmethod
342 def tgConv3D(testGen, op, rank, error_name=None):
343 pl, const = op["operands"]
344
345 if error_name != ErrorIf.WrongRank:
346 assert rank == 5
347
348 # IFM dimensions are NDHWC
349 ifm_shape = testGen.makeShape(rank)
James Ward30124a82023-02-02 14:56:33 +0000350 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100351
352 # Constrict the overall size of the shape when creating ERROR_IF tests
353 if error_name:
354 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
355 ifm_shape, max_dim=24, max_items=10000
356 )
357
358 # Get the filter depth/height/width from the operator parameters
359 filter_dhw = op["filter"]
360
361 # Generate a random OFM channel
James Ward30124a82023-02-02 14:56:33 +0000362 ofm_channel = testGen.makeDimension()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100363
364 # The filter dimensions are ODHWI
365 filter_shape = np.asarray(
366 [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
367 )
368
369 # The bias is OC
370 bias_shape = np.asarray([ofm_channel])
371
372 return [ifm_shape, filter_shape, bias_shape]
373
374 @staticmethod
375 def tgTransposeConv2D(testGen, op, rank, error_name=None):
376 pl, const = op["operands"]
377
378 if error_name != ErrorIf.WrongRank:
379 assert rank == 4
380
381 # IFM dimensions are NHWC
382 ifm_shape = testGen.makeShape(rank)
James Ward30124a82023-02-02 14:56:33 +0000383 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100384
385 # Constrict the overall size of the shape when creating ERROR_IF tests
386 if error_name:
387 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
388 ifm_shape, max_dim=24, max_items=10000
389 )
390
391 # Get the filter height/width from the operator parameters
392 filter_hw = op["filter"]
393
394 # Generate a random OFM depth
James Ward30124a82023-02-02 14:56:33 +0000395 ofm_depth = testGen.makeDimension()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100396
397 # The filter dimensions are OHWI
398 filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
399
400 # The bias is OC
401 bias_shape = np.asarray([ofm_depth])
402
403 return [ifm_shape, filter_shape, bias_shape]
404
405 @staticmethod
406 def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
407 pl, const = op["operands"]
408
409 if error_name != ErrorIf.WrongRank:
410 assert rank == 4
411 assert pl == 1 and const == 2
412
413 # IFM dimensions are NHWC
414 ifm_shape = testGen.makeShape(rank)
James Ward30124a82023-02-02 14:56:33 +0000415 ifm_shape = testGen.constrictBatchSize(ifm_shape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100416
417 # Constrict the overall size of the shape when creating ERROR_IF tests
418 if error_name:
419 ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
420 ifm_shape, max_dim=24, max_items=10000
421 )
422
423 # Get the filter height/width from the operator parameters
424 # Filter is KH, HW, C, M
425 filter_hw = op["filter"]
426
427 # Generate a random OFM depth, but don't let it get too big because
428 # the output depth is M * C
429 filter_m = (
James Ward30124a82023-02-02 14:56:33 +0000430 testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100431 ) + 1
432
433 # The filter dimensions are HWCM
434 filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])
435
436 # The bias is M * C
437 bias_shape = np.asarray([ifm_shape[3] * filter_m])
438
439 return [ifm_shape, filter_shape, bias_shape]
440
    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        """Create a pair of identical (N, H, W) shapes for FFT2D with H and
        W forced to powers of two.

        ERROR_IF variants can produce non power-of-two kernels or a
        mismatched pair of input shapes.
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        # NOTE(review): int(math.log(x, 2)) relies on float precision at
        # exact powers of two - math.log2 or int.bit_length would be safer;
        # confirm before changing.
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # A dimension of 1 must be incremented by 2, as 1+1=2 would
            # still be a power of two
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]
479
    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        """Create a single (N, H, W) shape for RFFT2D with H and W forced
        to powers of two.
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        # NOTE(review): int(math.log(x, 2)) relies on float precision at
        # exact powers of two - math.log2 or int.bit_length would be safer;
        # confirm before changing.
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]
512
513 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100514 def tgFullyConnected(testGen, op, rank, error_name=None):
515 pl, const = op["operands"]
516
517 if error_name != ErrorIf.WrongRank:
518 assert rank == 2
519
520 input_shape = testGen.makeShape(rank)
521
522 # Constrict the overall size of the shape when creating ERROR_IF tests
523 if error_name:
524 input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)
525
526 filter_oc = testGen.rng.integers(
527 low=testGen.args.tensor_shape_range[0],
528 high=testGen.args.tensor_shape_range[1],
529 size=1,
530 )[0]
531 filter_shape = np.asarray([filter_oc, input_shape[1]])
532
533 bias_shape = np.asarray([filter_oc])
534
535 return [input_shape, filter_shape, bias_shape]
536
537 @staticmethod
538 def tgMatmul(testGen, op, rank, error_name=None):
539 pl, const = op["operands"]
540
541 if error_name != ErrorIf.WrongRank:
542 assert rank == 3
543 assert pl == 2 and const == 0
544
545 a_shape = testGen.makeShape(rank)
546
547 # Constrict the overall size of the shape when creating ERROR_IF tests
548 if error_name:
549 a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)
550
551 # Get a random number for b_oc even if target shape is defined
552 b_oc = np.int32(
553 testGen.rng.integers(
554 low=testGen.args.tensor_shape_range[0],
555 high=testGen.args.tensor_shape_range[1],
556 size=1,
557 )
558 )[0]
559 # If N or H is large let b_oc be 1 to reduce output tensor size
560 if max(a_shape) > 1000:
561 b_oc = 1
562
563 b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
564 return [a_shape, b_shape]
565
566 @staticmethod
567 def tgConcat(testGen, opName, rank, error_name=None):
568 pl, const = opName["operands"]
569 shape = testGen.makeShape(rank)
570
571 # Create extra tensors to concat.
572 # Take into account value of pl when getting maximum number of concats
573 num_tensors = testGen.randInt(0, 4)
574 shape_list = []
575 for i in range(pl + const + num_tensors):
576 if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
577 remove = testGen.rng.choice([True, False])
578 wrongShape = shape.copy()
579
580 if remove and len(shape) > 1:
581 wrongShape = wrongShape[1:]
582 else:
583 wrongShape = list(wrongShape)
584 wrongShape.append(testGen.rng.integers(1, 10))
585
586 shape_list.append(wrongShape)
587 else:
588 shape_list.append(shape.copy())
589
590 return shape_list
591
    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Rework CONCAT shapes so the first shape is split along the
        concatenation axis across the remaining inputs.

        Returns shapeList unchanged for axis-related errors or when the
        axis dimension is too small to split.
        """
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            # Last iteration: the leftover length becomes the final input
            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
637
638
class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each tensor in each test."""

    def __init__(self):
        # Stateless helper class - all generators are static methods.
        pass
644
    class TVGInfo:
        """Enhanced tensor values information including data gen dict."""

        def __init__(self, tensorList, dataGenDict):
            # Tensors as returned by the serializer
            self.tensorList = tensorList
            # Data generation meta data (None when data was created directly)
            self.dataGenDict = dataGenDict
651
    # Default high value for random numbers
    # (each entry is the largest finite value representable in the type)
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
        DType.FP8E4M3: 448,
        DType.FP8E5M2: 57344,
    }

    # Default lowest normal values for random numbers
    # NOTE(review): the FP8 entries (2^-9, 2^-16) look like the smallest
    # subnormal values rather than the smallest normals - confirm intent.
    TVG_FLOAT_LOW_VALUE = {
        DType.FP32: np.exp2(-126),
        DType.FP16: np.exp2(-14),
        DType.BF16: np.exp2(-126),
        DType.FP8E4M3: np.exp2(-9),
        DType.FP8E5M2: np.exp2(-16),
    }
669
    @staticmethod
    def _get_data_range(testGen, dtype, highValueLookup, lowValueLookup=None):
        """Return an inclusive (low, high) data range for dtype, or None
        when dtype has no entry in highValueLookup.

        Combines the per operator table limits (highValueLookup /
        lowValueLookup) with the dtype's own range, falling back to the
        table limits with a warning when the combination is empty.
        """
        # Return a tuple of (low,high) data range values for the given data
        # type using a combination of per operator table limits, data limits
        # and user supplied ranges for FP numbers
        if dtype in highValueLookup:
            type_range = testGen.getDTypeRange(dtype, high_inclusive=True)
            high_val = highValueLookup[dtype]
            if lowValueLookup is not None and dtype in lowValueLookup:
                low_val = lowValueLookup[dtype]
            else:
                # Default to a symmetric range about zero
                low_val = -high_val
            # Set the values to something that won't produce infinity whilst
            # respecting the default ranges if more/less than the low/high
            # values
            data_range = (
                max(low_val, type_range[0]),
                min(high_val, type_range[1]),
            )
            if data_range[0] > data_range[1]:
                # Invalid data range from low to high created due to user
                # constraints revert to using internal ranges as they are
                # known to work
                msg = f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
                warnings.warn(msg)
                data_range = (low_val, high_val)
            return data_range
        return None
698
699 @staticmethod
Jeremy Johnson1271c442023-09-05 11:39:26 +0100700 def tvgLazyGenDefault(
701 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
702 ):
703 # Variable inputs versus constants
704 pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
Jeremy Johnson3eafe662024-01-10 13:13:35 +0000705 if "p_count" in argsDict:
706 # Override for operators like CONCAT
707 pCount = argsDict["p_count"]
708 cCount = argsDict["c_count"]
709 assert pCount + cCount == len(
710 shapeList
711 ), "Placeholders & Constant tensors must match shapes list"
712
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000713 tens_ser_list = []
Jeremy Johnson1271c442023-09-05 11:39:26 +0100714
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100715 if (
716 error_name is not None
717 or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100718 or "data_gen" not in testGen.TOSA_OP_LIST[opName]
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100719 ):
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000720 # Fall back to internal data gen when dealing with unsupported types or ops
721 data_range = argsDict["data_range"] if "data_range" in argsDict else None
722 for idx, info in enumerate(zip(shapeList, dtypeList)):
Jeremy Johnson30476252023-11-20 16:15:30 +0000723 roundMode = False
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000724 shape, dtype = info
Jeremy Johnson30476252023-11-20 16:15:30 +0000725 if "data_range_list" in argsDict:
726 data_range = argsDict["data_range_list"][idx]["range"]
727 roundMode = (
728 "round" in argsDict["data_range_list"][idx]
729 and argsDict["data_range_list"][idx]["round"] is True
730 )
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000731 if data_range is not None and dtype not in (
732 DType.FP16,
733 DType.FP32,
734 DType.BF16,
Won Jeon2c34b462024-02-06 18:37:00 +0000735 DType.FP8E4M3,
736 DType.FP8E5M2,
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000737 ):
738 # Change from inclusive to exclusive range
739 data_range = (data_range[0], data_range[1] + 1)
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000740 # Ignore lazy data gen option and create data array using any range limits
Won Jeon64e4bfe2024-01-18 06:31:55 +0000741
742 if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
Jeremy Johnson0a042992024-02-28 13:20:05 +0000743 if dtype == DType.SHAPE:
744 arr = np.int64(argsDict["fixed_data"][idx])
745 elif dtype == DType.INT8:
746 arr = np.int8(argsDict["fixed_data"][idx])
747 else:
748 assert False, "Unsupported fixed_data type"
Won Jeon64e4bfe2024-01-18 06:31:55 +0000749 else:
750 arr = testGen.getRandTensor(shape, dtype, data_range)
Jeremy Johnson30476252023-11-20 16:15:30 +0000751 if roundMode:
752 arr = np.round(arr)
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000753 if idx < pCount:
754 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
755 else:
756 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100757
Jeremy Johnson1271c442023-09-05 11:39:26 +0100758 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
759
760 # Create data generator meta-data
761 dg_type = argsDict["dg_type"]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100762 tens_data = {
763 "version": "0.1",
764 "tensors": {},
765 }
766 dg_tens_meta = tens_data["tensors"]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100767 for idx, shape in enumerate(shapeList):
768
769 tens_meta = {}
Won Jeon64e4bfe2024-01-18 06:31:55 +0000770 if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
771 tens_meta["generator"] = gtu.DataGenType(
772 gtu.DataGenType.FIXED_DATA
773 ).name
774 else:
775 tens_meta["generator"] = gtu.DataGenType(dg_type).name
776
Jeremy Johnson1271c442023-09-05 11:39:26 +0100777 tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
778 tens_meta["shape"] = [int(i) for i in shape]
779 tens_meta["input_pos"] = idx
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100780 tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()
Jeremy Johnson1271c442023-09-05 11:39:26 +0100781
782 if idx < pCount:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100783 tens_meta["input_type"] = "VARIABLE"
Jeremy Johnson1271c442023-09-05 11:39:26 +0100784 else:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100785 tens_meta["input_type"] = "CONSTANT"
Jeremy Johnson1271c442023-09-05 11:39:26 +0100786
787 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
788 info = {}
Won Jeon64e4bfe2024-01-18 06:31:55 +0000789 if (
790 tens_meta["generator"]
791 == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name
792 ):
793 info["data"] = [int(i) for i in argsDict["fixed_data"][idx]]
794 tens_meta["fixed_data_info"] = info
795 else:
796 # TODO - generate seed for this generator based on test
797 info["rng_seed"] = 42
Jeremy Johnson30476252023-11-20 16:15:30 +0000798
Won Jeon64e4bfe2024-01-18 06:31:55 +0000799 data_range = None
800 if "data_range_list" in argsDict:
801 data_range = argsDict["data_range_list"][idx]["range"]
802 if "round" in argsDict["data_range_list"][idx]:
803 info["round"] = argsDict["data_range_list"][idx]["round"]
804 elif "data_range" in argsDict:
805 data_range = argsDict["data_range"]
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000806
Won Jeon64e4bfe2024-01-18 06:31:55 +0000807 if data_range is None:
808 data_range = testGen.getDTypeRange(
809 dtypeList[idx], high_inclusive=True
810 )
811 info["range"] = [str(v) for v in data_range]
812 tens_meta["pseudo_random_info"] = info
Jeremy Johnson1271c442023-09-05 11:39:26 +0100813 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
814 info = {}
815 info["s"] = argsDict["s"]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100816 info["ks"] = int(argsDict["ks"])
817 if "acc_type" in argsDict:
818 # Convert type number into JSON name
819 info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
820 "json"
821 ]
822 if "kernel" in argsDict:
823 info["kernel"] = [int(k) for k in argsDict["kernel"]]
824 if "axis" in argsDict:
825 info["axis"] = int(argsDict["axis"])
Jeremy Johnson1271c442023-09-05 11:39:26 +0100826 tens_meta["dot_product_info"] = info
827 else:
828 # TODO - other data gen type
829 assert False, "TODO: support other data gen types"
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100830
831 # Using the finished generate config meta data - generate the data if
832 # needed and assign a tensor name from the serializer
833
834 # Need to generate data when not lazy or for the bias tensor as we need
835 # to work out if the bias data is non-zero for compliance
836 if not testGen.args.lazy_data_gen or (
837 idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
838 ):
839 # Give this tensor a temporary name until we get one from the serializer
840 temp_name = f"placeholder_{idx}"
841 dg_tens_meta[temp_name] = tens_meta
842 # Create data now using the temporary name to access meta details
843 data = testGen.dgl.get_tensor_data(temp_name, tens_data)
Won Jeon64e4bfe2024-01-18 06:31:55 +0000844 if tens_meta["data_type"] == "SHAPE":
845 # Tensor type SHAPE and Numpy file type must be the same
846 data = np.int64(data)
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100847 # Remove the item as we will give it the correct name later
848 del dg_tens_meta[temp_name]
849
850 if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
851 # The KS value used by compliance verification is altered when the
852 # bias data is non-zero
853 if max(abs(data)) > 0.0:
854 argsDict["ksb"] = argsDict["ks"] + 1
855
856 if testGen.args.lazy_data_gen:
857 data = None
858
859 if tens_meta["input_type"] == "VARIABLE":
860 tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
861 else:
862 tens = testGen.ser.addConst(shape, dtypeList[idx], data)
863
864 tens_ser_list.append(tens)
865 # Add the meta data to the list using the serializer tensor name
Jeremy Johnson1271c442023-09-05 11:39:26 +0100866 dg_tens_meta[tens.name] = tens_meta
867
Jeremy Johnson1271c442023-09-05 11:39:26 +0100868 return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
869
870 @staticmethod
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000871 def tvgNegate(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100872 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000873 # Integer test
874 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100875 pCount, cCount = op["operands"]
876 assert (
877 pCount == 1 and cCount == 0
878 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100879 # Must create tensors with values within accumulator (int32) negatable
880 # range
881 max_val = (1 << 31) - 1
882 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100883 arr = np.int32(
884 testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
885 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000886 tens_ser_list = []
887 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100888 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
889 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000890 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100891 else:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000892 # ERROR_IF or floating point test
893 return TosaTensorValuesGen.tvgLazyGenDefault(
894 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100895 )
896
Jeremy Johnson30476252023-11-20 16:15:30 +0000897 # Set the ADD/SUB data range to half the largest value to avoid infinities
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000898 TVG_FLOAT_HIGH_VALUE_ADDSUB = {
899 DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
900 DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
901 DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
902 }
903
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100904 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000905 def tvgAddSub(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Won Jeon74342e52024-01-09 00:34:40 +0000906 if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000907 # Make sure the integer operation does not cause value saturation - where
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100908 # the number wraps due to limited number of bits to store the answer
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000909 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100910 pCount, cCount = op["operands"]
911 assert (
912 pCount == 2 and cCount == 0
913 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000914 tens_ser_list = []
Won Jeon74342e52024-01-09 00:34:40 +0000915 add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
916 data_range = testGen.args.tensor_shape_range
917 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0], data_range)
918 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1], data_range)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100919 if add:
920 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
921 else:
922 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
923
924 # Work out the saturation limits
925 max_i32 = (1 << 31) - 1
926 min_i32 = -(1 << 31)
927 max_arr = np.full(shapeList[1], max_i32)
928 min_arr = np.full(shapeList[1], min_i32)
929
930 # Find how much values exceed the maximum/minimums
931 sat_max_arr = np.maximum(res_arr - max_arr, 0)
932 sat_min_arr = np.minimum(res_arr - min_arr, 0)
933
934 if not add:
935 # Swap saturation values and negate values as we need to perform opposite operations
936 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
937
938 # Create new array of unsaturated values by clipping values as needed
939 b_unsat_arr = b_arr
940 if (sat_max_arr != 0).any():
941 # Clip values that cause saturation
942 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
943 # Reduce axes in unsaturated tensor to match original tensor
944 for axis, dim in enumerate(b_arr.shape):
945 if dim != b_unsat_arr.shape[axis]:
946 assert (
947 dim == 1
948 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
949 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
950
951 if (sat_min_arr != 0).any():
952 # Clip values that cause saturation
953 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
954 # Reduce axes in unsaturated tensor to match original tensor
955 for axis, dim in enumerate(b_arr.shape):
956 if dim != b_unsat_arr.shape[axis]:
957 assert (
958 dim == 1
959 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
960 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
961
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000962 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100963 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
964 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000965 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100966 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
967 )
968
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000969 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100970 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000971 # ERROR_IF or floating point test
972 data_range = TosaTensorValuesGen._get_data_range(
973 testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
974 )
975 if data_range:
976 argsDict["data_range"] = data_range
977
978 return TosaTensorValuesGen.tvgLazyGenDefault(
979 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100980 )
981
982 @staticmethod
983 def tvgCondIfWhileLoop(
Jeremy Johnson587cc842024-02-08 11:45:44 +0000984 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100985 ):
986 if dtypeList[0] in (
987 DType.INT32,
988 DType.INT16,
989 DType.INT8,
990 ):
991 # Limit input tensors with cond_if_binary or while_loop to stop
992 # saturation of add/sub ops with int32 and keep all logical shift
993 # values between 0 to 31 for int16 or int8
Jeremy Johnson587cc842024-02-08 11:45:44 +0000994 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100995 pCount, cCount = op["operands"]
996 pRemain = pCount
Jeremy Johnson587cc842024-02-08 11:45:44 +0000997 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100998 for idx, shape in enumerate(shapeList[:]):
999 if dtypeList[0] == DType.INT32:
1000 arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
1001 else:
1002 arr = np.int32(
1003 testGen.rng.integers(low=0, high=32, size=shapeList[idx])
1004 )
1005 if pRemain > 0:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001006 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001007 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
1008 )
1009 pRemain -= 1
1010 else:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001011 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001012 testGen.ser.addConst(shape, dtypeList[idx], arr)
1013 )
1014
Jeremy Johnson587cc842024-02-08 11:45:44 +00001015 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001016 else:
Jeremy Johnson587cc842024-02-08 11:45:44 +00001017 return TosaTensorValuesGen.tvgLazyGenDefault(
1018 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001019 )
1020
1021 @staticmethod
1022 def tvgArithmeticRightShift(
Jeremy Johnson587cc842024-02-08 11:45:44 +00001023 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001024 ):
Jeremy Johnson587cc842024-02-08 11:45:44 +00001025 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001026 pCount, cCount = op["operands"]
1027 # Force value of operand[1] to be within [0, num_bits]
1028 assert (
1029 pCount == 2 and cCount == 0
1030 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
1031
Jeremy Johnson587cc842024-02-08 11:45:44 +00001032 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001033 for idx, shape in enumerate(shapeList[:]):
1034 if idx == 1:
1035 if dtypeList[idx] == DType.INT8:
1036 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
1037 elif dtypeList[idx] == DType.INT16:
1038 arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
1039 elif dtypeList[idx] == DType.INT32:
1040 arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
1041 elif error_name == ErrorIf.WrongInputType:
1042 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
1043 else:
1044 raise Exception("OpArithmeticRightShift: invalid input dtype")
1045 else:
1046 arr = testGen.getRandTensor(shape, dtypeList[idx])
Jeremy Johnson587cc842024-02-08 11:45:44 +00001047 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001048
Jeremy Johnson587cc842024-02-08 11:45:44 +00001049 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001050
1051 @staticmethod
Jeremy Johnson587cc842024-02-08 11:45:44 +00001052 def tvgReshape(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Won Jeon64e4bfe2024-01-18 06:31:55 +00001053 dtypeList[1] = DType.SHAPE
1054 shapeList[1] = [len(argsDict["new_shape"])]
1055 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1056 argsDict["fixed_data"] = [None, argsDict["new_shape"]]
1057
1058 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson587cc842024-02-08 11:45:44 +00001059 testGen, opName, dtypeList, shapeList, argsDict, error_name
Won Jeon64e4bfe2024-01-18 06:31:55 +00001060 )
1061
1062 @staticmethod
Jeremy Johnson587cc842024-02-08 11:45:44 +00001063 def tvgPad(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Tai Lye095da72024-01-25 22:00:18 +00001064 # argsDict["pad"] is 2D array, need to flatten it to get list of values
1065 pad_values = argsDict["pad"].flatten()
1066 dtypeList[1] = DType.SHAPE
1067 shapeList[1] = [len(pad_values)]
1068 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1069 argsDict["fixed_data"] = [None, pad_values]
1070
1071 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson587cc842024-02-08 11:45:44 +00001072 testGen, opName, dtypeList, shapeList, argsDict, error_name
Tai Lye095da72024-01-25 22:00:18 +00001073 )
1074
1075 @staticmethod
Jeremy Johnson587cc842024-02-08 11:45:44 +00001076 def tvgSlice(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
TatWai Chongf15bad82024-01-31 21:33:27 -08001077 dtypeList[1] = DType.SHAPE
1078 shapeList[1] = [len(argsDict["start"])]
1079 dtypeList[2] = DType.SHAPE
1080 shapeList[2] = [len(argsDict["size"])]
1081 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1082 argsDict["fixed_data"] = [None, argsDict["start"], argsDict["size"]]
1083
1084 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson587cc842024-02-08 11:45:44 +00001085 testGen, opName, dtypeList, shapeList, argsDict, error_name
TatWai Chongf15bad82024-01-31 21:33:27 -08001086 )
1087
1088 @staticmethod
Jeremy Johnson587cc842024-02-08 11:45:44 +00001089 def tvgTile(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Won Jeon64e4bfe2024-01-18 06:31:55 +00001090 dtypeList[1] = DType.SHAPE
1091 shapeList[1] = [len(argsDict["multiples"])]
1092 argsDict["fixed_data"] = [None, argsDict["multiples"]]
1093
1094 return TosaTensorValuesGen.tvgLazyGenDefault(
Jeremy Johnson587cc842024-02-08 11:45:44 +00001095 testGen, opName, dtypeList, shapeList, argsDict, error_name
Won Jeon64e4bfe2024-01-18 06:31:55 +00001096 )
1097
1098 @staticmethod
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00001099 def tvgSelect(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001100 # Set datatype of condition tensor to boolean
1101 dtypeList[0] = DType.BOOL
1102
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00001103 return TosaTensorValuesGen.tvgLazyGenDefault(
1104 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001105 )
1106
1107 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001108 def tvgIntDiv(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001109 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001110 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001111 pCount, cCount = op["operands"]
1112 assert (
1113 pCount == 2 and cCount == 0
1114 ), "Op.INTDIV must have 2 placeholders, 0 consts"
1115
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001116 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001117
1118 # Two invalid cases for Op.INTDIV:
1119 # 1. divisor == 0
1120 # 2. dividend == -(1<<31) and divisor == -1
1121 while True:
1122 dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1123 divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
1124
1125 if (divisor_arr == 0).any():
1126 continue
1127
1128 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
1129 continue
1130
1131 break
1132
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001133 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001134 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1135 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001136 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001137 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1138 )
1139
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001140 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001141 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001142 return TosaTensorValuesGen.tvgLazyGenDefault(
1143 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001144 )
1145
Jeremy Johnson30476252023-11-20 16:15:30 +00001146 # Set the MUL data range to the square root of the largest value
1147 # to avoid infinities
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001148 TVG_FLOAT_HIGH_VALUE_MUL = {
1149 DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1150 DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1151 DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1152 }
1153
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001154 @staticmethod
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001155 def tvgMul(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1156 if error_name is not None or dtypeList[0] in (
1157 DType.FP16,
1158 DType.BF16,
1159 DType.FP32,
1160 ):
1161 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001162 data_range = TosaTensorValuesGen._get_data_range(
1163 testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
1164 )
1165 if data_range:
1166 argsDict["data_range"] = data_range
1167
Jeremy Johnson0a042992024-02-28 13:20:05 +00001168 if dtypeList[0] != DType.SHAPE:
1169 # Need to supply shift tensor for MUL (not needed for MUL_SHAPE)
1170 dtypeList[2] = DType.INT8
1171 shapeList[2] = [1]
1172 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1173 argsDict["fixed_data"] = [None, None, [argsDict["shift"]]]
1174
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001175 return TosaTensorValuesGen.tvgLazyGenDefault(
1176 testGen, opName, dtypeList, shapeList, argsDict, error_name
1177 )
1178 else:
1179 # Integer test
1180 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001181 pCount, cCount = op["operands"]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001182
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001183 tens_ser_list = []
1184
1185 # Make sure multiply result in int32 range
Won Jeon74342e52024-01-09 00:34:40 +00001186 if dtypeList[0] == DType.SHAPE:
1187 shift = 0
1188 else:
1189 shift = argsDict["shift"]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001190 if dtypeList[0] == DType.INT8:
1191 num_bits = 8
1192 elif dtypeList[0] == DType.INT16:
1193 num_bits = 16
Won Jeon74342e52024-01-09 00:34:40 +00001194 elif dtypeList[0] in (DType.INT32, DType.SHAPE):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001195 num_bits = 32
1196 elif error_name == ErrorIf.WrongInputType:
1197 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001198 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001199 raise Exception("OpMul: invalid input dtype")
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001200
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001201 for idx, shape in enumerate(shapeList[:]):
Won Jeon74342e52024-01-09 00:34:40 +00001202 if dtypeList[idx] == DType.SHAPE:
1203 low = testGen.args.tensor_shape_range[0]
1204 high = testGen.args.tensor_shape_range[1]
1205 else:
1206 low = -(2 ** (num_bits - 1))
1207 high = (2 ** (num_bits - 1)) - 1
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001208
1209 a_arr = np.int32(
1210 testGen.rng.integers(low=low, high=high, size=shapeList[0])
1211 )
1212 b_arr = np.int32(
1213 testGen.rng.integers(low=low, high=high, size=shapeList[1])
1214 )
1215
1216 i = 0
1217 while True:
1218
1219 a_arr_64 = a_arr.astype(np.int64)
1220 b_arr_64 = b_arr.astype(np.int64)
1221
1222 if shift > 0:
1223 rounding = 1 << (shift - 1)
1224 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001225 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001226 result_arr = a_arr_64 * b_arr_64
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001227
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001228 if (result_arr > -(2**31)).all() and (
1229 result_arr <= ((2**31) - 1)
1230 ).all():
1231 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001232
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001233 i = i + 1
1234 a_arr = a_arr // 2
1235 b_arr = b_arr // 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001236
Won Jeon74342e52024-01-09 00:34:40 +00001237 if dtypeList[0] == DType.SHAPE:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001238 # MUL_SHAPE with 2 inputs
Won Jeon74342e52024-01-09 00:34:40 +00001239 tens_ser_list.append(
1240 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
1241 )
1242 tens_ser_list.append(
1243 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
1244 )
1245 else:
Jeremy Johnson0a042992024-02-28 13:20:05 +00001246 # MUL with 3 inputs (3rd is shift)
Won Jeon74342e52024-01-09 00:34:40 +00001247 tens_ser_list.append(
1248 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1249 )
1250 tens_ser_list.append(
1251 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1252 )
Jeremy Johnson0a042992024-02-28 13:20:05 +00001253 tens_ser_list.append(
1254 testGen.ser.addPlaceholder([1], DType.INT8, np.int8([shift]))
1255 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001256
1257 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001258
1259 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001260 def tvgConcat(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001261 count = len(shapeList) - testGen.args.num_const_inputs_concat
1262 if count < 1:
1263 count = 1
1264 if testGen.args.num_const_inputs_concat == 0:
1265 count = len(shapeList)
1266
Won Jeon74342e52024-01-09 00:34:40 +00001267 op = testGen.TOSA_OP_LIST[opName]
1268 if op["op"] == Op.CONCAT_SHAPE:
1269 # Set the axis to 0
1270 shapeList = TosaTensorGen.tgConcatConstInput(
1271 testGen, shapeList, 0, error_name
1272 )
1273 else:
1274 shapeList = TosaTensorGen.tgConcatConstInput(
1275 testGen, shapeList, argsDict["axis"], error_name
1276 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001277
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001278 # Override default pCount/cCount for operator
1279 argsDict["p_count"] = count
1280 argsDict["c_count"] = len(shapeList) - count
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001281
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001282 return TosaTensorValuesGen.tvgLazyGenDefault(
1283 testGen, opName, dtypeList, shapeList, argsDict, error_name
1284 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001285
1286 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001287 def tvgLogicalShift(
1288 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1289 ):
1290 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001291 pCount, cCount = op["operands"]
1292 assert (
1293 pCount == 2 and cCount == 0
1294 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
1295 values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1296 shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001297 tens_ser_list = []
1298 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001299 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1300 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001301 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001302 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1303 )
1304
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001305 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001306
    @staticmethod
    def tvgEqual(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create input tensors for EQUAL tests.

        For integer types some positions are forced to match between the two
        inputs - purely random data would almost never compare equal.
        """
        if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
            # Integer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"

            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])

            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        testGen.rng.integers(
                            0, max(shapeList[0][axis], shapeList[1][axis])
                        )
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                # Copy the value across so the tensors match at this position
                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
1353
    @staticmethod
    def tvgReduceSum(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create input tensors for REDUCE_SUM tests.

        INT32 values are limited so a sum along any axis cannot overflow
        int32; float values are limited so the sum cannot reach infinity
        (except for dot-product compliance tests which control their own
        data).
        """
        dtype = dtypeList[0]
        if dtype == DType.INT32:
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
            # Limit values so that the sum cannot exceed the range of an int32 during
            # summation of any axis
            range_val = int((1 << 31) / max(shapeList[0]))
            values_arr = np.int32(
                testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
            )
            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or dot product floating point test
            # NOTE(review): argsDict["dg_type"] holds a gtu.DataGenType value
            # elsewhere in this file, but is compared here against
            # gtu.ComplianceMode.DOT_PRODUCT - confirm the two enums share
            # values, otherwise this should be gtu.DataGenType.DOT_PRODUCT
            if (
                error_name is None
                and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
            ):
                # Limit ranges for (non error & non compliance) tests by using
                # values that can be summed on any axis to not hit infinity
                highval_lookup = {
                    dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
                    / max(shapeList[0])
                }
                data_range = TosaTensorValuesGen._get_data_range(
                    testGen, dtype, highval_lookup
                )
                assert data_range is not None
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
1395
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001396 @staticmethod
1397 def tvgReduceProduct(
1398 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1399 ):
1400 dtype = dtypeList[0]
1401 if error_name is None:
1402 # Limit ranges for (non error) tests by using
1403 # values that can be multiplied on any axis to not hit infinity
1404 highval_lookup = {
1405 dtype: math.pow(
1406 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
1407 1 / max(shapeList[0]),
1408 )
1409 }
1410 data_range = TosaTensorValuesGen._get_data_range(
1411 testGen, dtype, highval_lookup
1412 )
1413 assert data_range is not None
1414 argsDict["data_range"] = data_range
1415
1416 return TosaTensorValuesGen.tvgLazyGenDefault(
1417 testGen, opName, dtypeList, shapeList, argsDict, error_name
1418 )
1419
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001420 @staticmethod
1421 def tvgResize(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1422 data_range = TosaTensorValuesGen._get_data_range(
1423 testGen,
1424 dtypeList[0],
1425 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1426 )
1427 if data_range:
1428 argsDict["data_range"] = data_range
1429 # Needed for compliance
1430 argsDict["max_abs_value"] = data_range[1]
1431
Tai Lyc5c2a7e2024-02-22 23:26:28 +00001432 scale_values = argsDict["scale"]
1433 offset_values = argsDict["offset"]
1434 border_values = argsDict["border"]
1435 dtypeList[1] = DType.SHAPE
1436 dtypeList[2] = DType.SHAPE
1437 dtypeList[3] = DType.SHAPE
1438 shapeList[1] = [len(scale_values)]
1439 shapeList[2] = [len(offset_values)]
1440 shapeList[3] = [len(border_values)]
1441 argsDict["fixed_data"] = [None, scale_values, offset_values, border_values]
1442
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001443 return TosaTensorValuesGen.tvgLazyGenDefault(
1444 testGen, opName, dtypeList, shapeList, argsDict, error_name
1445 )
1446
Jeremy Johnson30476252023-11-20 16:15:30 +00001447 # Set the POW exponent high data range
1448 TVG_FLOAT_HIGH_VALUE_POW_EXP = {
1449 DType.FP32: 10.0,
1450 DType.FP16: 10.0,
1451 DType.BF16: 10.0,
1452 }
1453 # POW highest base value (within a safe margin of error) that can be raised
1454 # to +ve exponent that doesn't become Infinity
1455 TVG_FLOAT_HIGH_VALUE_POW_BASE = {
1456 DType.FP32: math.floor(
1457 math.pow(
1458 TVG_FLOAT_HIGH_VALUE[DType.FP32],
1459 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1460 )
1461 ),
1462 DType.FP16: math.floor(
1463 math.pow(
1464 TVG_FLOAT_HIGH_VALUE[DType.FP16],
1465 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1466 )
1467 ),
1468 DType.BF16: math.floor(
1469 math.pow(
1470 TVG_FLOAT_HIGH_VALUE[DType.BF16],
1471 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1472 )
1473 ),
1474 }
1475 # POW lowest base value (within a safe margin of error) that can be raised
1476 # to -ve exponent that doesn't become Infinity
1477 TVG_FLOAT_LOW_VALUE_POW_BASE = {
1478 DType.FP32: math.ceil(
1479 math.pow(
1480 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
1481 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1482 )
1483 * 1000
1484 )
1485 / 1000,
1486 DType.FP16: math.ceil(
1487 math.pow(
1488 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
1489 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1490 )
1491 * 1000
1492 )
1493 / 1000,
1494 DType.BF16: math.ceil(
1495 math.pow(
1496 1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
1497 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1498 )
1499 * 1000
1500 )
1501 / 1000,
1502 }
1503
1504 @staticmethod
1505 def tvgPow(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1506 if error_name is not None:
1507 return TosaTensorValuesGen.tvgLazyGenDefault(
1508 testGen, opName, dtypeList, shapeList, argsDict, error_name
1509 )
1510 dtype = dtypeList[0]
1511 # Different ranges for POW
1512 test_set = argsDict["s"]
1513 if test_set == 0:
1514 # Positive base with fractional exponent
1515 base_range = TosaTensorValuesGen._get_data_range(
1516 testGen,
1517 dtype,
1518 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1519 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1520 )
1521 exp_range = TosaTensorValuesGen._get_data_range(
1522 testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
1523 )
1524 exp_round = False
1525 else:
1526 # Integer exponent
1527 exp_range = TosaTensorValuesGen._get_data_range(
1528 testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
1529 )
1530 exp_round = True
1531 if test_set == 1:
1532 # Positive base
1533 base_range = TosaTensorValuesGen._get_data_range(
1534 testGen,
1535 dtype,
1536 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1537 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1538 )
1539 else:
1540 assert test_set == 2
1541 # Negative base
1542 # Supply new look up tables with negative values
1543 base_range = TosaTensorValuesGen._get_data_range(
1544 testGen,
1545 dtype,
1546 {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
1547 {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
1548 )
1549
1550 data_range_list = (
1551 {
1552 "range": base_range,
1553 },
1554 {
1555 "range": exp_range,
1556 "round": exp_round,
1557 },
1558 )
1559 argsDict["data_range_list"] = data_range_list
1560 return TosaTensorValuesGen.tvgLazyGenDefault(
1561 testGen, opName, dtypeList, shapeList, argsDict, error_name
1562 )
1563
1564 @staticmethod
1565 def tvgLogRsqrt(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1566 # LOG & RSQRT data range from lowest expressible positive number to
1567 # largest to avoid NaNs
1568 data_range = TosaTensorValuesGen._get_data_range(
1569 testGen,
1570 dtypeList[0],
1571 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1572 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1573 )
1574 if data_range:
1575 argsDict["data_range"] = data_range
1576
1577 return TosaTensorValuesGen.tvgLazyGenDefault(
1578 testGen, opName, dtypeList, shapeList, argsDict, error_name
1579 )
1580
    # Set the EXP data range to the log of the largest to smallest values
    # to avoid infinities or making the result zero
    # i.e. with x in [ln(LOW_VALUE), ln(HIGH_VALUE)], exp(x) stays within
    # the expressible range of the dtype
    TVG_FLOAT_HIGH_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
    # Lowest usable exponent - log of the smallest expressible positive value
    TVG_FLOAT_LOW_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
    }
1593
1594 @staticmethod
1595 def tvgExp(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1596 data_range = TosaTensorValuesGen._get_data_range(
1597 testGen,
1598 dtypeList[0],
1599 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1600 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1601 )
1602 if data_range:
1603 argsDict["data_range"] = data_range
1604
1605 return TosaTensorValuesGen.tvgLazyGenDefault(
1606 testGen, opName, dtypeList, shapeList, argsDict, error_name
1607 )
1608
1609 @staticmethod
1610 def tvgFullyConnected(
1611 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1612 ):
1613 dtype = dtypeList[0]
1614 if (
1615 error_name is None
1616 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
Jeremy Johnson718f3472023-11-30 14:18:19 +00001617 and dtype in (DType.BF16,)
Jeremy Johnson30476252023-11-20 16:15:30 +00001618 ):
Jeremy Johnson718f3472023-11-30 14:18:19 +00001619 # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
Jeremy Johnson30476252023-11-20 16:15:30 +00001620 # Limit ranges for (non error & non compliance) FP tests by using
1621 # values that can be multiplied on any axis to not hit infinity/NaN
1622 IC = shapeList[0][1]
1623 highval_lookup = {
1624 dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
1625 }
1626 data_range = TosaTensorValuesGen._get_data_range(
1627 testGen, dtype, highval_lookup
1628 )
1629 assert data_range is not None
1630 argsDict["data_range"] = data_range
1631
1632 return TosaTensorValuesGen.tvgLazyGenDefault(
1633 testGen, opName, dtypeList, shapeList, argsDict, error_name
1634 )
1635
Jeremy Johnson708da822023-11-15 16:25:45 +00001636 @staticmethod
1637 def tvgCast(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1638 in_dtype = dtypeList[0]
1639 out_dtype = argsDict["out_type"]
1640 # Create look up to limit input tensor to output type maximums to avoid
1641 # FP infinities and saturation of integers
1642 out_range = testGen.getDTypeRange(out_dtype, high_inclusive=True)
1643 highval_lookup = {in_dtype: out_range[1]}
1644 data_range = TosaTensorValuesGen._get_data_range(
1645 testGen,
1646 in_dtype,
1647 highval_lookup,
1648 )
1649
1650 assert data_range is not None
1651 argsDict["data_range"] = data_range
1652
1653 return TosaTensorValuesGen.tvgLazyGenDefault(
1654 testGen, opName, dtypeList, shapeList, argsDict, error_name
1655 )
1656
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001657 @staticmethod
1658 def tvgGather(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1659 K = shapeList[0][1]
1660
1661 # Fix the type of the indices tensor
1662 dtypeList[1] = DType.INT32
1663
1664 dtype = dtypeList[0]
1665 if not gtu.dtypeIsSupportedByCompliance(dtype):
1666 # Test unsupported by data generator
1667 op = testGen.TOSA_OP_LIST[opName]
1668 pCount, cCount = op["operands"]
1669 assert (
1670 pCount == 2 and cCount == 0
1671 ), "Op.GATHER must have 2 placeholders, 0 consts"
1672
1673 tens_ser_list = []
1674 for idx, shape in enumerate(shapeList):
1675 dtype = dtypeList[idx]
1676 if idx != 1:
1677 arr = testGen.getRandTensor(shape, dtype)
1678 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1679 else:
1680 # Limit data range of indices tensor upto K (exclusive)
1681 arr = testGen.getRandTensor(shape, dtype, (0, K))
1682 # To match old functionality - create indices as CONST
1683 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1684
1685 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1686
1687 else:
1688 # ERROR_IF or floating point test
1689 # Use inclusive values upto index K for indices tensor
1690 data_range_list = (
1691 {"range": None},
1692 {"range": (0, K - 1)},
1693 )
1694 argsDict["data_range_list"] = data_range_list
1695
1696 return TosaTensorValuesGen.tvgLazyGenDefault(
1697 testGen, opName, dtypeList, shapeList, argsDict, error_name
1698 )
1699
1700 @staticmethod
1701 def tvgScatter(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1702 K = shapeList[0][1]
1703 W = shapeList[2][1]
1704
1705 # Work out an indices tensor here with data that doesn't exceed the
1706 # dimension K of the values_in tensor and does NOT repeat the same K
1707 # location as needed by the spec:
1708 # "It is not permitted to repeat the same output index within a single
1709 # SCATTER operation and so each output index occurs at most once."
1710 assert K >= W, "Op.SCATTER W must be smaller or equal to K"
1711
1712 # Fix the type of the indices tensor
1713 dtypeList[1] = DType.INT32
1714
1715 dtype = dtypeList[0]
1716 if not gtu.dtypeIsSupportedByCompliance(dtype):
1717 # Test unsupported by data generator
1718 op = testGen.TOSA_OP_LIST[opName]
1719 pCount, cCount = op["operands"]
1720 assert (
1721 pCount == 3 and cCount == 0
1722 ), "Op.SCATTER must have 3 placeholders, 0 consts"
1723
1724 tens_ser_list = []
1725 for idx, shape in enumerate(shapeList):
1726 dtype = dtypeList[idx]
1727 if idx != 1:
1728 arr = testGen.getRandTensor(shape, dtype)
1729 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1730 else:
1731 # Create the indices array
1732 assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
1733 arr = []
1734 for n in range(shape[0]):
1735 # Get a shuffled list of output indices (0 to K-1) and
1736 # limit length to W
1737 arr.append(testGen.rng.permutation(K)[:W])
1738 indices_arr = np.array(arr, dtype=np.int32) # (N, W)
1739 # To match old functionality - create indices as CONST
1740 tens_ser_list.append(
1741 testGen.ser.addConst(shape, dtype, indices_arr)
1742 )
1743
1744 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1745
1746 else:
1747 # ERROR_IF or floating point test
1748 # Use inclusive values upto index K for indices tensor
1749 data_range_list = (
1750 {"range": None},
1751 {"range": (0, K - 1)},
1752 {"range": None},
1753 )
1754 argsDict["data_range_list"] = data_range_list
1755
1756 return TosaTensorValuesGen.tvgLazyGenDefault(
1757 testGen, opName, dtypeList, shapeList, argsDict, error_name
1758 )
1759
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001760
1761class TosaArgGen:
1762 """Argument generators create exhaustive or random lists of attributes for
1763 operators that take attributes or other parameters.
1764
1765 The return value is a list of (descriptive_name, [arglist]) tuples where
1766 the descriptive_name is appended to the test name and the arglist is expanded
1767 as arguments to the operator build function.
1768 """
1769
    def __init__(self):
        """No state to initialise - all argument generators are static methods."""
        pass
1772
1773 @staticmethod
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001774 def _add_data_generators(testGen, opName, dtype, arg_list, error_name):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001775 """Add extra tests for each type of data generator for this op."""
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001776 if (
1777 error_name is None
1778 and "data_gen" in testGen.TOSA_OP_LIST[opName]
1779 and gtu.dtypeIsSupportedByCompliance(dtype)
1780 ):
Won Jeon2c34b462024-02-06 18:37:00 +00001781 if dtype in [
1782 DType.FP16,
1783 DType.FP32,
1784 DType.BF16,
1785 DType.FP8E4M3,
1786 DType.FP8E5M2,
1787 ]:
Jeremy Johnson1271c442023-09-05 11:39:26 +01001788 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
1789 else:
1790 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
1791 else:
1792 # Error test or No data generator types listed - assume random
1793 dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
1794
1795 # Expand arg list with other data generator types
1796 new_arg_list = []
1797 for dg_type in dataGenTypesList:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001798 for arg_str, args_dict in arg_list:
1799 args_dict["dg_type"] = dg_type
Jeremy Johnson1271c442023-09-05 11:39:26 +01001800 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
Jeremy Johnson30476252023-11-20 16:15:30 +00001801 if error_name is None:
1802 num_test_sets = (
1803 args_dict["num_test_sets"]
1804 if "num_test_sets" in args_dict
1805 else 0
1806 )
1807 else:
1808 num_test_sets = 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01001809
1810 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
1811 # Extra tests for each dot product test set
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001812 dot_products = args_dict["dot_products"]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001813 if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001814 shape_info = (
1815 " ({})".format(testGen.shapeStr(args_dict["shape"]))
1816 if "shape" in args_dict
1817 else ""
1818 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01001819 print(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001820 f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
Jeremy Johnson1271c442023-09-05 11:39:26 +01001821 )
1822 continue
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001823 # KS and acc_type is required by all dot product generators
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001824 assert "ks" in args_dict
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001825 assert "acc_type" in args_dict
Jeremy Johnson1271c442023-09-05 11:39:26 +01001826
Jeremy Johnson30476252023-11-20 16:15:30 +00001827 num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS
1828
1829 if num_test_sets > 0:
1830 for s in range(0, num_test_sets):
1831 new_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001832 new_args_dict = args_dict.copy()
1833 new_args_dict["s"] = s
1834 new_arg_list.append((new_arg_str, new_args_dict))
Jeremy Johnson30476252023-11-20 16:15:30 +00001835 else:
1836 # Default is a single test
1837 new_arg_list.append((arg_str, args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001838
1839 return new_arg_list
1840
1841 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001842 def agNone(testGen, opName, shapeList, dtype, error_name=None):
1843 """A trivial argument generator for operators that don't take any
1844 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001845 arg_list = TosaArgGen._add_data_generators(
1846 testGen,
1847 opName,
1848 dtype,
1849 [("", {})],
1850 error_name,
1851 )
1852 # Return list of tuples: (arg_str, args_dict)
1853 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001854
1855 @staticmethod
Jeremy Johnson30476252023-11-20 16:15:30 +00001856 def agPow(testGen, opName, shapeList, dtype, error_name=None):
1857 """Pow operator needs different test sets to cover random numbers
1858 without creating NaNs or Infs"""
1859 arg_list = TosaArgGen._add_data_generators(
1860 testGen,
1861 opName,
1862 dtype,
1863 [("", {"num_test_sets": 3})],
1864 error_name,
1865 )
1866 # Return list of tuples: (arg_str, args_dict)
1867 return arg_list
1868
1869 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001870 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
1871 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001872 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001873 shape = shapeList[0]
1874
1875 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001876 # Set too small axis
1877 axes = [testGen.rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001878 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001879 # Set too large axis
1880 axes = [testGen.rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001881 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001882 # Create tests for each dimension
1883 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001884
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001885 opid = testGen.TOSA_OP_LIST[opName]["op"]
1886
1887 for a in axes:
1888 args_dict = {"axis": int(a)}
1889 if opid == Op.REDUCE_SUM:
1890 args_dict["dot_products"] = gtu.product(shape)
1891 args_dict["shape"] = shape
1892 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1893 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1894
1895 arg_list.append(("axis{}".format(a), args_dict))
1896
1897 arg_list = TosaArgGen._add_data_generators(
1898 testGen,
1899 opName,
1900 dtype,
1901 arg_list,
1902 error_name,
1903 )
1904 # Return list of tuples: (arg_str, args_dict)
1905 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001906
1907 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001908 def _calculate_sparsity(num_tests, sparsity_factor):
1909 sparsity = num_tests // sparsity_factor + 1
1910 # If there are only a small number of tests, just select them all
1911 if sparsity < 13:
1912 sparsity = 1
1913 # To get a variety of parameter combinations sparsity should not be a
1914 # multiple of 2, 3 or 5
1915 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1916 sparsity += 1
1917 return sparsity
1918
    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        """Build stride/pad/dilation argument combinations for convolutions.

        Used by CONV2D, CONV3D and DEPTHWISE_CONV2D.

        Normal runs sweep a sparse sample of all valid stride/pad/dilation
        combinations; with --level8k only the 8K-level boundary values are
        produced. Each emitted args_dict also carries the compliance info
        (acc_type, ks, dot_products, shape) for the data generators.

        Returns a list of (arg_str, args_dict) tuples.
        """
        # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        # Shape: Batches, (Depth), Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: (OFM channels), (KD), KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Op type checks
        conv3d = opName.startswith("conv3d")
        depthwise = opName.startswith("depthwise")

        # Check the rank
        rank = 5 if conv3d else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits channels
        k_rank = rank - 2
        # Depthwise filters are (KH, KW, C, M) so the kernel starts at index 0
        k_pos = 0 if depthwise else 1
        k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
        # compliance size - KS
        k_size = gtu.product(k_shape)
        if not depthwise:
            # Non-depthwise kernels also multiply across the input channels
            k_size *= ifm_shape[-1]

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            if error_name == ErrorIf.PadSmallerZero:
                p_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                # Stride must be greater than 1 to force non-integer error
                startStride = (
                    1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
                )
                s_vals = [
                    x for x in range(startStride, testGen.args.max_conv_stride + 1)
                ]
            strides = {x for x in itertools.product(*([s_vals] * k_rank))}
            if error_name == ErrorIf.DilationSmallerOne:
                d_vals = [testGen.rng.choice(range(-5, 1))]
            else:
                d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
            dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[0, bigPadding]] * (k_rank * 2))
                            )
                        }
                    )
                bigStride = 8
                strides.update(
                    {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
                )
                bigDilation = 7
                dilations.update(
                    {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
                )
            # No output size limit for normal level testing
            max_dim_size = None

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(paddings) * len(strides) * len(dilations), sparsity_factor
            )
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            dilation_shape = [1] * k_rank
            pad_shape = [0] * k_rank * 2
            if conv3d:
                # Small stride apart from for big kernel (see below) to keep
                # tensor size/calculation small
                stride_shape = [1] * k_rank
                for idx in range(k_rank):
                    pad_offset = idx * 2
                    if k_shape[idx] == bigKernel:
                        # Padding shape needs to account for tensor shape
                        pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                        pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                        # Big stride to reduce output size
                        stride_shape[idx] = bigKernel
                    else:
                        # Account for kernel size
                        pad_shape[pad_offset] = k_shape[idx] - 1
            else:
                # Always have a large stride with extra padding and dilation to keep
                # tensor calculation reasonable
                stride_shape = [bigKernel] * k_rank
                for idx in range(k_rank):
                    # Dilation shape must account for kernel size
                    dilation_shape[idx] = bigKernel // k_shape[idx]
                    # Padding shape needs to accommodate tensor/kernel & dilation
                    pad_offset = idx * 2
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

            strides = {tuple(stride_shape)}
            dilations = {tuple(dilation_shape)}
            paddings = {tuple(pad_shape)}
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        # Sweep the (sorted for determinism) combinations, keeping 1-in-sparsity
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a positive
                        # sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
                        and (
                            k_rank < 3
                            or (
                                (ifm_shape[3] - 1 + p[4] + p[5])
                                > d[2] * (k_shape[2] - 1)
                            )
                        )
                    ):
                        # Work out the output shape and whether the stride
                        # divides the padded/dilated extent exactly
                        remainders = []
                        outputs = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            partial = (
                                ifm_shape[index + 1]
                                - 1
                                + p[pad_offset]
                                + p[pad_offset + 1]
                                - (k_shape[index] - 1) * d[index]
                            )
                            remainders.append(partial % s[index])
                            outputs.append((partial // s[index]) + 1)

                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            if (
                                max_dim_size is not None
                                and max(outputs) >= max_dim_size
                            ):
                                # Test will consume too much memory - skip it
                                continue

                            # Compliance - number of dot product calculations
                            if depthwise:
                                # N*OH*OW*C*M
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, *filter_shape[2:])
                                )
                            else:
                                # N*OH*OW*OC or N*OD*OH*OW*OC
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, filter_shape[0])
                                )
                            args_dict = {
                                "acc_type": accum_dtype,
                                "stride": s,
                                "pad": p,
                                "dilation": d,
                                "kernel": k_shape,
                                "ks": k_size,
                                "dot_products": dots,
                                "shape": ifm_shape,
                            }

                            # Support for larger values than 9 needs different delimiter
                            delim = "" if max(s + p + d) <= 9 else "x"
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        delim.join([str(x) for x in s]),
                                        delim.join([str(x) for x in p]),
                                        delim.join([str(x) for x in d]),
                                    ),
                                    args_dict,
                                )
                            )
                    n += 1

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtypes[0],
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2145
2146 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01002147 def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
2148
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002149 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002150 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01002151
2152 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson1271c442023-09-05 11:39:26 +01002153 accum_dtype = gtu.get_wrong_output_type(opName, testGen.rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01002154 elif error_name == ErrorIf.WrongInputType:
2155 # Pick some potentially correct output dtype if input type is incorrect
2156 accum_dtype = DType.INT32
2157 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +01002158 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01002159
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002160 # Set up compliance info
2161 args_dict = {
2162 "acc_type": accum_dtype,
2163 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
2164 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
2165 "shape": shapeList[0],
2166 }
2167
2168 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
2169
2170 arg_list = TosaArgGen._add_data_generators(
2171 testGen,
2172 opName,
2173 input_dtype,
2174 arg_list,
2175 error_name,
2176 )
2177 # Return list of tuples: (arg_str, args_dict)
2178 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002179
2180 @staticmethod
2181 def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
2182 # Get valid accumulate type(s)
2183 if dtype == DType.INT8:
2184 accum_dtypes = [DType.INT32]
2185 elif dtype == DType.INT16:
2186 accum_dtypes = [DType.INT48]
2187 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002188 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002189 elif dtype == DType.BF16:
2190 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002191 elif dtype == DType.FP32:
2192 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002193 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2194 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002195 elif error_name is None:
2196 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
2197
2198 if error_name == ErrorIf.WrongOutputType:
2199 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson1271c442023-09-05 11:39:26 +01002200 accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01002201 elif error_name == ErrorIf.WrongInputType:
2202 # Pick some potentially correct output dtype if input type is incorrect
2203 accum_dtypes = [DType.INT32]
2204
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002205 # Set up compliance info
2206 args_dict = {
2207 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
2208 # Set dot_products = N*H*W
2209 "dot_products": gtu.product(
2210 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
2211 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002212 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002213 }
2214
2215 # Create arg tuple of string and dict
2216 arg_list = []
2217 for a in accum_dtypes:
2218 d = args_dict.copy()
2219 d["acc_type"] = a
2220 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01002221
2222 arg_list = TosaArgGen._add_data_generators(
2223 testGen,
2224 opName,
2225 dtype,
2226 arg_list,
2227 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002228 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002229 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01002230 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002231
2232 @staticmethod
2233 def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002234 arg_list = []
2235
Jeremy Johnson0c716862023-04-13 17:18:19 +01002236 if testGen.args.level8k and error_name is not None:
2237 # Don't produce negative large tests
2238 return arg_list
2239
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002240 ifm_shape = shapeList[0]
2241 filter_shape = shapeList[1]
2242
Jeremy Johnson1271c442023-09-05 11:39:26 +01002243 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01002244
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002245 # Must be rank 4
2246 if error_name != ErrorIf.WrongRank:
2247 assert len(ifm_shape) == 4
2248 assert len(filter_shape) == 4
2249
Jeremy Johnson0c716862023-04-13 17:18:19 +01002250 k_shape = tuple(filter_shape[1:3])
Jeremy Johnson95a67102024-01-10 14:16:39 +00002251 # compliance size - KS
2252 k_size = gtu.product((*k_shape, ifm_shape[3]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002253
Jeremy Johnson0c716862023-04-13 17:18:19 +01002254 if not testGen.args.level8k:
2255 # Generate comprehensive argument lists
2256 # - except for named errors, which use specific invalid value(s)
2257 smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
2258 if error_name == ErrorIf.PadLargerEqualKernel:
2259 max_filter_size = -max(k_shape[0], k_shape[1])
2260 p_vals = [
2261 testGen.rng.choice(range(max_filter_size - 10, max_filter_size))
2262 ]
2263 else:
2264 p_vals = [
2265 x
2266 for x in range(
2267 smallest_padding_size, testGen.args.max_conv_padding + 1
2268 )
2269 ]
2270 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2271 if error_name == ErrorIf.StrideSmallerOne:
2272 # Can't use stride=0, as it is used to derive output shape, as a divisor
2273 s_vals = [testGen.rng.choice(range(-5, 0))]
2274 else:
2275 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
2276 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002277
Jeremy Johnson0c716862023-04-13 17:18:19 +01002278 if not error_name and testGen.args.oversize:
2279 # add some oversize argument values
2280 if max(ifm_shape) < 64:
2281 bigPadding = 9
2282 paddings.update(
2283 {
2284 x
2285 for x in itertools.product(
2286 *([[smallest_padding_size, bigPadding]] * 4)
2287 )
2288 }
2289 )
2290 bigStride = 8
2291 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
2292
2293 # There are too many parameter combinations, so generate them sparsely,
2294 # very sparse for negative tests
2295 sparsity_factor = 2 if error_name else 10
2296 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
2297 # If there are only a small number of tests, just select them all
2298 if sparsity < 13:
2299 sparsity = 1
2300 # To get a variety of parameter combinations sparsity should not be a
2301 # multiple of 2, 3 or 5
2302 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
2303 sparsity += 1
2304 else:
2305 # Only test 8k levels boundaries
2306 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2307 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2308 bigPadding = bigKernel
2309
2310 pad_shape = [0] * (len(k_shape) * 2)
2311 stride_shape = [1] * len(k_shape)
2312 # The point at which input dimension combined with the stride will
2313 # create large output sizes!
2314 LARGE_SIZE = 2
2315 for idx in range(len(k_shape)):
2316 pad_offset = idx * 2
2317 if k_shape[idx] == bigKernel:
2318 # Set large stride
2319 stride_shape[idx] = bigKernel
2320 # Use negative output padding to reduce shape size
2321 pad_shape[pad_offset] = -(bigPadding - 1)
2322 if ifm_shape[idx + 1] > LARGE_SIZE:
2323 pad_shape[pad_offset + 1] = -(bigPadding - 1)
2324 else:
2325 # The other dimension should be the bigKernel
2326 alt_idx = 1 - idx
2327 if (
2328 k_shape[alt_idx] == bigKernel
2329 and ifm_shape[alt_idx + 1] < LARGE_SIZE
2330 ):
2331 # As the input is small, the large stride won't
2332 # affect the output so we can add some padding
2333 pad_shape[pad_offset + 1] = bigPadding
2334
2335 strides = {tuple(stride_shape)}
2336 paddings = {tuple(pad_shape)}
2337
2338 # Currently allow all combinations that are reasonable size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002339 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002340
2341 n = 0
2342 for s in sorted(list(strides)):
2343 for p in sorted(list(paddings)):
TatWai Chong24594f52022-06-08 00:48:04 -07002344 if n % sparsity == 0:
2345 # Determine the output shape
Jeremy Johnson0c716862023-04-13 17:18:19 +01002346 oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
2347 ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
TatWai Chong24594f52022-06-08 00:48:04 -07002348 os = [ifm_shape[0], oh, ow, filter_shape[0]]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002349
Jeremy Johnson95a67102024-01-10 14:16:39 +00002350 # N*OH*OW*OC
2351 dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
2352 args_dict = {
2353 "acc_type": accum_dtype,
2354 "stride": s,
2355 "pad": p,
2356 "kernel": k_shape,
2357 "ks": k_size,
2358 "dot_products": dots,
2359 "shape": ifm_shape,
2360 "out_shape": os,
2361 }
2362
Jeremy Johnson0c716862023-04-13 17:18:19 +01002363 # Support for larger values than 9 needs different delimiter
2364 delim = "" if max(s + p) <= 9 else "x"
TatWai Chong24594f52022-06-08 00:48:04 -07002365 arg_list.append(
2366 (
James Ward8b390432022-08-12 20:48:56 +01002367 "acc{}_st{}_pad{}_os{}".format(
2368 testGen.typeStr(accum_dtype),
Jeremy Johnson0c716862023-04-13 17:18:19 +01002369 delim.join([str(x) for x in s]),
2370 delim.join([str(x) for x in p]),
TatWai Chong24594f52022-06-08 00:48:04 -07002371 "x".join([str(x) for x in os]),
2372 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00002373 args_dict,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002374 )
TatWai Chong24594f52022-06-08 00:48:04 -07002375 )
2376 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002377
Jeremy Johnson95a67102024-01-10 14:16:39 +00002378 arg_list = TosaArgGen._add_data_generators(
2379 testGen,
2380 opName,
2381 dtypes[0],
2382 arg_list,
2383 error_name,
2384 )
2385 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002386 return arg_list
2387
2388 @staticmethod
2389 def agPad(testGen, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002390 rank = len(shapeList[0])
2391
2392 # Exhaustively test combinations of padding on each side of each dimension
2393 # - the range of padding values is defined by pad_min and pad_max
2394 # - for padding >9, the name format needs to be more distinctive
2395 pad_min, pad_max = 0, 1
2396 pad_values = [x for x in range(pad_min, pad_max + 1)]
2397 if error_name == ErrorIf.PadSmallerZero:
2398 pad_values = [x for x in range(-2, 0)]
2399 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
2400 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
2401
2402 if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
2403 pad_const_int = testGen.getRandNumberDType(dtype)
2404 pad_const_fp = 0
Won Jeon2c34b462024-02-06 18:37:00 +00002405 elif dtype in (
2406 DType.FP16,
2407 DType.BF16,
2408 DType.FP32,
2409 DType.FP8E4M3,
2410 DType.FP8E5M2,
2411 ):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002412 pad_const_int = 0
2413 pad_const_fp = testGen.getRandNumberDType(dtype)
2414 else:
2415 return []
2416
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002417 list_shape_pad_values = list(shape_pad_values)
2418 # If we are producing tests for rank 6 or greater use sparsity
2419 if len(list_shape_pad_values) > 1024:
2420 sparsity_factor = 2 if error_name else 120
2421 sparsity = TosaArgGen._calculate_sparsity(
2422 len(list_shape_pad_values), sparsity_factor
2423 )
2424 else:
2425 sparsity = 1
2426
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002427 # Build arg list
2428 arg_list = []
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002429 for n, paddings in enumerate(list_shape_pad_values):
James Ward8b390432022-08-12 20:48:56 +01002430 paddings = list(paddings)
2431 args_valid = True
2432
2433 if error_name == ErrorIf.PadSmallerZero:
2434 # Prevent negative output shapes while ensuring still testing for negative padding
2435 for i in range(rank):
2436 dim_after_padding = (
2437 paddings[i][0] + paddings[i][1] + shapeList[0][i]
2438 )
2439 if dim_after_padding < 1:
2440 paddings[i] = (0, 0)
2441 if all([p > -1 for p in paddings[i]]):
2442 args_valid = False
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002443 if args_valid and n % sparsity == 0:
James Ward8b390432022-08-12 20:48:56 +01002444 name = "pad"
2445 for r in range(rank):
2446 before, after = paddings[r]
2447 name = f"{name}{before}{after}"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002448 args_dict = {
2449 "pad": np.array(paddings),
2450 "pad_const_int": pad_const_int,
2451 "pad_const_fp": pad_const_fp,
2452 }
2453 arg_list.append((name, args_dict))
James Ward8b390432022-08-12 20:48:56 +01002454
2455 if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
2456 warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002457
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002458 arg_list = TosaArgGen._add_data_generators(
2459 testGen,
2460 opName,
2461 dtype,
2462 arg_list,
2463 error_name,
2464 )
2465
2466 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002467 return arg_list
2468
2469 @staticmethod
2470 def agPooling(testGen, opName, shapeList, dtype, error_name=None):
2471 arg_list = []
2472
2473 shape = shapeList[0]
2474 if error_name != ErrorIf.WrongRank:
2475 assert len(shape) == 4
2476
Jeremy Johnson0c716862023-04-13 17:18:19 +01002477 test_level8k = testGen.args.level8k and error_name is None
2478
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002479 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002480 startKernel = 2
2481 startPad = 0
2482 if not test_level8k:
2483 # Generate comprehensive argument lists
2484 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2485 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2486 # Stride must be greater than 1 to force non-integer error
2487 s_vals = [
2488 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2489 ]
2490 strides = {x for x in itertools.product(*([s_vals] * 2))}
2491 k_vals = [
2492 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2493 ]
2494 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2495 max_dim_size = None
2496 else:
2497 # Only test 8k levels
2498 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2499 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2500 strides = {(1, bigStride), (bigStride, 4)}
2501 kernels = {(1, bigKernel), (bigKernel, 3)}
2502 paddings = set()
2503 for s in sorted(list(strides)):
2504 for k in sorted(list(kernels)):
2505 padding = []
2506 for idx in range(len(k)):
2507 total_padding = s[idx] - shape[idx + 1] + k[idx]
2508 while total_padding < 0:
2509 # Must meet: shape + padding > kernel
2510 total_padding += s[idx]
2511 if total_padding < k[idx]:
2512 padding.extend([0, total_padding])
2513 else:
2514 # Note this may produce padding >= k[idx] which is not
2515 # allowed - but will be ignored in the creation loop below
2516 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2517 paddings.add(tuple(padding))
2518 # Create a limit for the output dimensions size
2519 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002520
James Ward8b390432022-08-12 20:48:56 +01002521 if opName == "max_pool2d":
2522 accum_dtypes = [None] # max_pool has no accumulate dtype
2523 elif dtype == DType.INT8 or dtype == DType.INT16:
2524 accum_dtypes = [DType.INT32]
2525 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002526 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002527 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002528 accum_dtypes = [DType.FP32]
Won Jeon2c34b462024-02-06 18:37:00 +00002529 elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
2530 accum_dtypes = [DType.FP16]
James Ward8b390432022-08-12 20:48:56 +01002531 elif error_name is None:
2532 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2533 else:
2534 # Set to something for the ErrorIf case which has
2535 # incorrect input data-type
2536 accum_dtypes = [DType.INT32]
2537
Jeremy Johnson01e1c1c2024-02-07 16:09:09 +00002538 if error_name == ErrorIf.WrongAccumulatorType:
2539 accum_dtypes = list(gtu.usableDTypes(excludes=accum_dtypes))
2540
Jeremy Johnson0c716862023-04-13 17:18:19 +01002541 if not test_level8k:
2542 if testGen.args.oversize:
2543 # add some oversize argument values
2544 bigStride = 7
2545 bigKernel = 9
2546 strides.update(
2547 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002548 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002549 kernels.update(
2550 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2551 )
2552 if max(shape) < 64:
2553 # padding must be less than the kernel size
2554 bigPadding = bigKernel - 1
2555 paddings.update(
2556 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2557 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002558
Jeremy Johnson0c716862023-04-13 17:18:19 +01002559 # There are too many parameter combinations, so generate them sparsely,
2560 # very sparse for negative tests
2561 sparsity_factor = 2 if error_name else 500
2562 sparsity = (
2563 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2564 )
2565 else:
2566 # We have already limited test output combinations for 8k tests
2567 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002568
James Ward8b390432022-08-12 20:48:56 +01002569 arg_str = (
2570 "acc{}_st{}_kern{}_pad{}"
2571 if accum_dtypes[0] is not None
2572 else "st{}_kern{}_pad{}"
2573 )
2574
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002575 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01002576 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002577 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01002578
2579 # Support for larger values than 9 needs different delimiter
2580 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01002581 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01002582 delim.join([str(x) for x in stride]),
2583 delim.join([str(x) for x in kern]),
2584 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01002585 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002586 args_dict = {
2587 "stride": stride,
2588 "pad": pad,
2589 "kernel": kern,
2590 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002591 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002592 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
2593 }
James Ward8b390432022-08-12 20:48:56 +01002594
2595 if accum is not None:
2596 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002597 args_dict["acc_type"] = accum
2598 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002599
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002600 n = 0
James Ward8b390432022-08-12 20:48:56 +01002601 for a in accum_dtypes:
2602 for s in sorted(list(strides)):
2603 for p in sorted(list(paddings)):
2604 for k in sorted(list(kernels)):
2605 if error_name in [
2606 ErrorIf.StrideSmallerOne,
2607 ErrorIf.KernelSmallerOne,
2608 ErrorIf.PadSmallerZero,
2609 ErrorIf.PadLargerEqualKernel,
2610 ]:
2611 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
2612 testGen, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002613 )
James Ward8b390432022-08-12 20:48:56 +01002614 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002615 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002616 get_arg_list_element(a, sNew, pNew, kNew, shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002617 )
James Ward8b390432022-08-12 20:48:56 +01002618 elif (
2619 n % sparsity == 0
2620 # padding must not exceed the kernel size
2621 and p[0] < k[0]
2622 and p[1] < k[0]
2623 and p[2] < k[1]
2624 and p[3] < k[1]
2625 # the padded shape must exceed the kernel size
2626 and (shape[1] + p[0] + p[1]) > k[0]
2627 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002628 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002629 partial_h = shape[1] + p[0] + p[1] - k[0]
2630 partial_w = shape[2] + p[2] + p[3] - k[1]
2631 remainder_h = partial_h % s[0]
2632 remainder_w = partial_w % s[1]
2633 output_h = partial_h // s[0] + 1
2634 output_w = partial_w // s[1] + 1
2635 # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
James Ward8b390432022-08-12 20:48:56 +01002636 if (
2637 # the parameters must produce integer exact output
2638 error_name != ErrorIf.PoolingOutputShapeNonInteger
2639 and remainder_h == 0
2640 and remainder_w == 0
2641 ) or (
2642 error_name == ErrorIf.PoolingOutputShapeNonInteger
2643 and (remainder_h != 0 or remainder_w != 0)
2644 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002645 if (
2646 max_dim_size is not None
2647 and max(output_h, output_w) > max_dim_size
2648 ):
2649 # Test will consume too much memory - skip it
2650 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002651 # Dot products = N*OH*OW*C
2652 dp = gtu.product(
2653 (shape[0], output_h, output_w, shape[3])
2654 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002655 arg_list.append(
2656 get_arg_list_element(a, s, p, k, dp, shape)
2657 )
James Ward8b390432022-08-12 20:48:56 +01002658 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002659
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002660 # Now add data generator types
2661 arg_list = TosaArgGen._add_data_generators(
2662 testGen,
2663 opName,
2664 dtype,
2665 arg_list,
2666 error_name,
2667 )
2668
2669 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002670 return arg_list
2671
    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype, error_name=None):
        """Build the argument list for CAST tests.

        Enumerates the valid output data types for the given input type
        (or deliberately wrong ones for the ERROR_IF cases) and returns
        one argument entry per output type.
        """
        arg_list = []

        # Enumerate the output types here
        if error_name == ErrorIf.WrongOutputType:
            dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
        elif inDtype == DType.INT8:
            dtypeList = [
                DType.BOOL,
                DType.INT16,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT16:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT32:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT16,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FP16:
            dtypeList = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP8E4M3,
                DType.FP8E5M2,
            ]
        elif inDtype == DType.BF16:
            dtypeList = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP32,
                DType.FP8E4M3,
                DType.FP8E5M2,
            ]
        elif inDtype == DType.FP32:
            dtypeList = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP8E4M3,
                DType.FP8E5M2,
            ]
        elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]:
            dtypeList = [DType.FP16, DType.BF16, DType.FP32]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output type for incorrect input type
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        # One test argument entry per output type
        for dtype in dtypeList:
            arg_list.append(
                ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
            )

        # Now add data generator types
        # NOTE(review): `dtype` here is the leaked loop variable from above
        # (the LAST output type in dtypeList), not `inDtype` - other ag*
        # methods pass their input type; confirm this is intentional
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        return arg_list
2759
    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        """Build the argument list for RESCALE tests.

        Enumerates combinations of output type and the scale32,
        double_round and per_channel options, skipping combinations that
        are illegal per the TOSA specification unless the requested
        ERROR_IF case specifically needs them.
        """
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                # These output types can have a non-zero zero-point, so the
                # OutputZeroPointNotZero error cannot be produced with them
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                # Only produce WrongOutputType tests for genuinely wrong types
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition.  Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                {
                                    "output_dtype": outDtype,
                                    "scale": scale32,
                                    "double_round": double_round,
                                    "per_channel": per_channel,
                                },
                            )
                        )

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            inDtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2871
2872 @staticmethod
2873 def agMul(testGen, opName, shapeList, dtype, error_name=None):
2874 arg_list = []
2875
2876 if dtype is DType.INT32:
2877 for p in range(testGen.args.num_rand_permutations):
2878
2879 shift = testGen.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002880 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002881 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002882 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002883
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002884 arg_list = TosaArgGen._add_data_generators(
2885 testGen,
2886 opName,
2887 dtype,
2888 arg_list,
2889 error_name,
2890 )
2891 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002892 return arg_list
2893
2894 @staticmethod
2895 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
2896 arg_list = []
2897
Jeremy Johnson587cc842024-02-08 11:45:44 +00002898 for round in (True, False):
2899 args_dict = {
2900 "round": round,
2901 }
2902 arg_list.append((f"round{round}", args_dict))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002903
Jeremy Johnson587cc842024-02-08 11:45:44 +00002904 arg_list = TosaArgGen._add_data_generators(
2905 testGen,
2906 opName,
2907 dtype,
2908 arg_list,
2909 error_name,
2910 )
2911 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002912 return arg_list
2913
Luke Hutton57287132023-02-06 14:54:18 +00002914 @staticmethod
2915 def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
2916 arg_list = []
2917
Jeremy Johnsonc8330812024-01-18 16:57:28 +00002918 shape = shapeList[0]
2919 dot_products = gtu.product(shape)
2920 ks = 2 * shape[1] * shape[2] # 2*H*W
2921 for inverse in (True, False):
2922 args_dict = {
2923 "dot_products": dot_products,
2924 "shape": shape,
2925 "ks": ks,
2926 "acc_type": dtype,
2927 "inverse": inverse,
2928 }
2929 arg_list.append((f"inverse{inverse}", args_dict))
Luke Hutton57287132023-02-06 14:54:18 +00002930
Jeremy Johnsonc8330812024-01-18 16:57:28 +00002931 arg_list = TosaArgGen._add_data_generators(
2932 testGen,
2933 opName,
2934 dtype,
2935 arg_list,
2936 error_name,
2937 )
2938 # Return list of tuples: (arg_str, args_dict)
Luke Hutton57287132023-02-06 14:54:18 +00002939 return arg_list
2940
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00002941 @staticmethod
2942 def agRFFT2d(testGen, opName, shapeList, dtype, error_name=None):
2943 arg_list = []
2944
2945 shape = shapeList[0]
2946 dot_products = gtu.product(shape)
2947 ks = shape[1] * shape[2] # H*W
2948 args_dict = {
2949 "dot_products": dot_products,
2950 "shape": shape,
2951 "ks": ks,
2952 "acc_type": dtype,
2953 }
2954 arg_list.append(("", args_dict))
2955
2956 arg_list = TosaArgGen._add_data_generators(
2957 testGen,
2958 opName,
2959 dtype,
2960 arg_list,
2961 error_name,
2962 )
2963 # Return list of tuples: (arg_str, args_dict)
2964 return arg_list
2965
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002966 # Helper function for reshape. Gets some factors of a larger number.
2967 @staticmethod
2968 def getFactors(val, start=1):
2969 factors = []
2970
2971 for i in range(start, int(np.sqrt(val)) + 1):
2972 if (val % i) == 0:
2973 factors.append(i)
2974
2975 return factors
2976
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        """Build the argument list for RESHAPE tests.

        Randomly constructs new shapes with the same total element count
        as the input by combining its factors, avoiding duplicate output
        shapes across permutations.
        """
        arg_list = []

        origShape = shapeList[0]
        totalElements = gtu.product(origShape)
        factors = TosaArgGen.getFactors(totalElements)

        # Find new shapes up to the number of permutations asked for
        # This code is NOT fast. Fortunately, the numbers are fairly small.
        for p in range(testGen.args.num_rand_permutations):
            # Rank from 1 to TOSA_TENSOR_MAX_RANK
            newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
            if len(factors) < newRank:
                # Not enough factors to make a shape of this rank
                continue

            # escape_counter limits the generation of new shapes to a reasonable time
            for escape_counter in range(100):

                # Generate the new shape of the chosen new rank
                newShape = []
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # Last dimension takes whatever element count is left
                newShape.append(remainingElements)

                # Check for duplicates
                duplicate = False
                for name, args_dict in arg_list:
                    if args_dict["new_shape"] == newShape:
                        duplicate = True
                        break

                if not duplicate:
                    outShape = "x".join([str(x) for x in newShape])
                    arg_list.append(
                        (
                            "perm{}_rank{}_out{}".format(p, newRank, outShape),
                            {"new_shape": newShape},
                        )
                    )
                    # Found an output shape for this permutation
                    break

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        return arg_list
3037
3038 @staticmethod
3039 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
3040 arg_list = []
3041
3042 ifm_shape = shapeList[0]
3043
3044 if error_name == ErrorIf.IndexOutsideBounds:
3045 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
3046 incorrect_small_index = range(-len(ifm_shape), 0)
3047 permutations = [p for p in itertools.permutations(incorrect_large_index)]
3048 permutations.extend(
3049 [p for p in itertools.permutations(incorrect_small_index)]
3050 )
3051 elif error_name == ErrorIf.IndexUsedTwice:
3052 # Create list with a duplicated index
3053 perm_range = list(range(len(ifm_shape)))
3054 index_choice = testGen.rng.choice(range(len(perm_range)))
3055 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
3056 permutations = [p for p in itertools.permutations(perm_range)]
3057
3058 else:
3059 # Get all permutations
3060 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
3061
3062 # Limit to possible permutations from shape dimension or argument setting
3063 limit = min(len(permutations), testGen.args.num_rand_permutations)
3064
3065 # Get random permutation generator that uses all permutations
3066 random_permutations = testGen.rng.permutation(permutations)
3067
3068 # Create list of required amount of permutations
3069 arg_list = [
evacha0198477222024-01-26 12:25:32 +00003070 ("perm{}".format(p), {"perms": random_permutations[p].tolist()})
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003071 for p in range(limit)
3072 ]
evacha0198477222024-01-26 12:25:32 +00003073 # Now add data generator types
3074 arg_list = TosaArgGen._add_data_generators(
3075 testGen,
3076 opName,
3077 dtype,
3078 arg_list,
3079 error_name,
3080 )
3081 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003082 return arg_list
3083
3084 @staticmethod
3085 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
3086 arg_list = []
3087
3088 ifm_shape = shapeList[0]
3089 rank = len(ifm_shape)
3090
3091 for p in range(testGen.args.num_rand_permutations):
3092 start = []
3093 size = []
3094
3095 valid = True
3096
3097 for i in range(rank):
3098 if ifm_shape[i] > 1:
3099 start.append(testGen.randInt(0, ifm_shape[i]))
3100 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
3101
3102 # Invalid slice size?
3103 if size[i] == 0:
3104 valid = False
3105 else:
3106 start.append(0)
3107 size.append(1)
3108
3109 if valid:
3110 # If ERROR_IF test required then incorrect start, size will be returned
3111 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
3112 testGen, error_name, ifm_shape, start, size
3113 )
evacha017f7d4252024-01-24 12:08:09 +00003114 arg_list.append(("perm{}".format(p), {"start": start, "size": size}))
3115 # Now add data generator types
3116 arg_list = TosaArgGen._add_data_generators(
3117 testGen,
3118 opName,
3119 dtype,
3120 arg_list,
3121 error_name,
3122 )
3123 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003124 return arg_list
3125
3126 @staticmethod
3127 def agTile(testGen, opName, shapeList, dtype, error_name=None):
3128 arg_list = []
3129
3130 ifm_shape = shapeList[0]
3131 rank = len(ifm_shape)
3132
3133 for p in range(testGen.args.num_rand_permutations):
3134
3135 # Pick a few random, but small multiple values
3136 # because otherwise this has a tendency to generate
3137 # enormous tensors
3138 multiples = []
3139 for i in range(rank):
3140 if ifm_shape[i] > 1000:
3141 # Multiple of 1 if ifm_shape dimension is large to reduce
3142 # tensor size
3143 multiples.append(1)
3144 elif max(ifm_shape) > 1000:
3145 multiples.append(2)
3146 else:
3147 multiples.append(testGen.randInt(1, 4))
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003148 arg_list.append(("perm{}".format(p), {"multiples": multiples}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003149
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003150 # Now add data generator types
3151 arg_list = TosaArgGen._add_data_generators(
3152 testGen,
3153 opName,
3154 dtype,
3155 arg_list,
3156 error_name,
3157 )
3158 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003159 return arg_list
3160
    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Generate scale/offset/border argument permutations for RESIZE.

        For each legal {mode, output dtype} combination, randomly chosen
        parameter-generation strategies produce candidate scale, offset and
        border values, which are then filtered/adjusted so the resulting
        output dimensions are valid (or deliberately invalid for ERROR_IF
        tests).  Returns a list of (arg_str, args_dict) tuples.

        NOTE(review): assumes ifm_shape is NHWC-like with height at index 1
        and width at index 2 — confirm against the RESIZE build function.
        """
        arg_list = []
        ifm_shape = shapeList[0]

        # Strategy: common aspect-ratio scaling with letterbox/pillarbox borders
        def get_aspect_ratio_resize_params():
            common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
            aspect_ratio = testGen.rng.choice(common_aspect_ratios)
            invert = testGen.rng.choice((False, True))
            letterbox = testGen.rng.choice((False, True))

            scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
            scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
            scale_y_d = scale_x_d = 1
            offset_x = offset_y = 0

            if letterbox:
                max_border = scale_y_n
                border_y = testGen.randInt(low=0, high=max_border)
                border_x = 0
            else:
                # Pillarboxing
                border_y = 0
                max_border = scale_x_n
                border_x = testGen.randInt(low=0, high=max_border)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)

            return scale, offset, border

        # Strategy: power-of-two upscaling or small-denominator downscaling
        def get_upscale_downscale_params():
            valid_params = False
            while not valid_params:
                upscale = testGen.rng.choice((False, True))

                # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
                origin_sampling = testGen.rng.choice((False, True))

                if upscale:
                    shift = testGen.randInt(low=1, high=4)
                    scale_x_d = scale_y_d = 1
                    scale_x_n = scale_y_n = (
                        1 << shift if origin_sampling else 2 << shift
                    )
                    border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                    offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
                else:
                    scale_x_n = 1
                    scale_y_n = 1

                    # Return list of valid scale_*_d values (max value 4) given input dim shape
                    def get_valid_denom(ifm_dim):
                        return [x for x in range(1, 5) if ifm_dim % x == 1]

                    # Generate list of valid downscale values and choose one randomly
                    valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                    valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                    if not valid_scale_y_ds and not valid_scale_x_ds:
                        # Bad parameters, skip
                        continue

                    if not valid_scale_y_ds:
                        scale_y_d = 1
                    else:
                        scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                    if not valid_scale_x_ds:
                        scale_x_d = 1
                    else:
                        scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                    border_x = border_y = 0
                    offset_y = testGen.randInt(0, 16 * scale_y_n)
                    offset_x = testGen.randInt(0, 16 * scale_x_n)
                valid_params = True

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        # Strategy: fully random parameters, capped at the 8K level max scale
        def get_rand_params():
            # Reduce the denominator if scale_n/scale_d exceeds max_scale
            def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
                scale = scale_n / scale_d
                if scale > max_scale:
                    factor = scale / max_scale
                    new_scale_d = math.ceil(scale_d * factor)
                    assert scale_n / new_scale_d <= max_scale
                    scale_d = new_scale_d
                return scale_d

            # Scale
            scale_y_n = testGen.randInt(low=1, high=(1 << 11))
            scale_x_n = testGen.randInt(low=1, high=(1 << 11))

            scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

            scale_y_d = fix_scale_to_max_scale(
                scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_x_d = fix_scale_to_max_scale(
                scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )

            # Offsets and border within the scale
            offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
            border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        # Strategy: extreme scales for the 8K level tests
        def get_level_8k_params():
            # Create 64x scale - 64/1 to 2048/32
            scale_d = testGen.randInt(
                low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
            # Create half to fifth scaling
            scale_d_alt = testGen.randInt(low=2, high=6)
            scale_n_alt = 1
            switch = testGen.rng.choice((False, True))
            if switch:
                scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
            else:
                scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)

            offset_y = testGen.rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
            offset_x = testGen.rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
            offset = (offset_y, offset_x)
            border_y = testGen.rng.choice((-16 * scale[0], 0, scale[0] - 1))
            border_x = testGen.rng.choice((-16 * scale[2], 0, scale[2] - 1))
            border = (border_y, border_x)
            return scale, offset, border

        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.BF16:
                outputDTypeList = [DType.BF16]
            elif dtype == DType.FP32:
                outputDTypeList = [DType.FP32]
            elif dtype == DType.FP8E4M3:
                outputDTypeList = [DType.FP8E4M3]
            elif dtype == DType.FP8E5M2:
                outputDTypeList = [DType.FP8E5M2]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                perm = 0
                while perm < testGen.args.num_rand_permutations:
                    # Random choice of type of params we are testing
                    if not testGen.args.level8k:
                        _rnd_param_fn = testGen.rng.choice(
                            (
                                get_rand_params,
                                get_upscale_downscale_params,
                                get_aspect_ratio_resize_params,
                            )
                        )
                        scale, offset, border = _rnd_param_fn()
                    else:
                        scale, offset, border = get_level_8k_params()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        # Look for non-integer test
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            if perm > 0:
                                perm += 1
                            continue
                    else:
                        # Alter the scaling factors to make the output integer
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1
                        # Make sure we are still within max scaling
                        if (
                            scale_y_n / scale_y_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
                            scale_x_n / scale_x_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
                            # Skip the test as it is using too large a scaling factor
                            if perm > 0:
                                perm += 1
                            continue

                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1

                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        if not testGen.args.level8k or perm > 0:
                            perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= gtu.MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= gtu.MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]

                    # Common for all data types
                    if error_name is not None:
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        {
                            "mode": mode,
                            "scale": scale,
                            "offset": offset,
                            "border": border,
                            "output_dtype": outputDTypeNew,
                        },
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
3493
3494 @staticmethod
3495 def agTable(testGen, opName, shapeList, dtype, error_name=None):
3496 arg_list = []
3497
3498 if dtype == DType.INT8:
3499 table = np.int32(
3500 testGen.rng.integers(low=-128, high=128, size=[256])
3501 ).tolist()
3502 else: # INT16
3503 table = np.int32(
3504 testGen.rng.integers(low=-32768, high=32768, size=[513])
3505 ).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003506 # Make sure all slopes are within REQUIRE min/max 16-bit int
3507 for idx in range(len(table) - 1):
3508 slope = table[idx + 1] - table[idx]
3509 # Alter the next table entry to force the slope to be ok
3510 if slope > 32767:
3511 table[idx + 1] -= slope - 32767
3512 if slope < -32768:
3513 table[idx + 1] -= slope + 32768
3514 slope = table[idx + 1] - table[idx]
3515 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003516 arg_list.append(
3517 (
3518 "",
Jeremy Johnson587cc842024-02-08 11:45:44 +00003519 {"table": table},
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003520 )
3521 )
Jeremy Johnson587cc842024-02-08 11:45:44 +00003522 # Now add data generator types
3523 arg_list = TosaArgGen._add_data_generators(
3524 testGen,
3525 opName,
3526 dtype,
3527 arg_list,
3528 error_name,
3529 )
3530 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003531 return arg_list
3532
3533 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
3534 # CondIf generates the condition values here.
3535 # Convert to tensors in the build function, along with the
3536 # then and else blocks
3537 arg_list = []
3538
3539 for c in [False, True]:
Jeremy Johnson587cc842024-02-08 11:45:44 +00003540 arg_list.append(("cond{}".format(int(c)), {"condition": c}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003541
Jeremy Johnson587cc842024-02-08 11:45:44 +00003542 # Now add data generator types
3543 arg_list = TosaArgGen._add_data_generators(
3544 testGen,
3545 opName,
3546 dtype,
3547 arg_list,
3548 error_name,
3549 )
3550 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003551 return arg_list
3552
3553 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
3554 # While loop: 0 iterations, 1, more than 1
3555 arg_list = []
3556
Jeremy Johnson587cc842024-02-08 11:45:44 +00003557 for iterations in [0, 1, 4]:
3558 arg_list.append(("iter{}".format(iterations), {"iterations": iterations}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003559
Jeremy Johnson587cc842024-02-08 11:45:44 +00003560 # Now add data generator types
3561 arg_list = TosaArgGen._add_data_generators(
3562 testGen,
3563 opName,
3564 dtype,
3565 arg_list,
3566 error_name,
3567 )
3568 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003569 return arg_list