blob: 4630f35cbc0ca7e4877051d31c30c1345dac3419 [file] [log] [blame]
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
4import math
James Ward8b390432022-08-12 20:48:56 +01005import warnings
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01006
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
16# DTypeNames, DType, Op and ResizeMode are convenience variables to the
17# flatc-generated types that should be enums, but aren't
18
19
20class TosaQuantGen:
21 """QuantizedInfo random generator helper functions.
22
23 Specify with 'qgen': in the operator defintion.
24 """
25
26 def __init__(self):
27 pass
28
29 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +000030 def getZeroPoint(testGen, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010031
32 if dtype == DType.INT8:
Jeremy Johnson00423432022-09-12 17:27:37 +010033 if testGen.args.zeropoint is not None:
34 return min(127, max(-128, testGen.args.zeropoint))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010035 return testGen.randInt(-128, 128)
36 elif dtype == DType.UINT8:
Jeremy Johnson00423432022-09-12 17:27:37 +010037 if testGen.args.zeropoint is not None:
38 return min(255, max(0, testGen.args.zeropoint))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010039 return testGen.randInt(0, 256)
40 elif error_name in [
41 ErrorIf.InputZeroPointNotZero,
42 ErrorIf.WeightZeroPointNotZero,
43 ErrorIf.OutputZeroPointNotZero,
44 ]:
45 zero_point = testGen.randInt(-128, 128)
46 if zero_point == 0:
47 zero_point = 1
48 return zero_point
49 return 0
50
51 @staticmethod
52 def qgUnary(testGen, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010053 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000054 qinfo = [
55 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
56 TosaQuantGen.getZeroPoint(testGen, dtype),
57 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010058 elif error_name == ErrorIf.OutputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000059 qinfo = [
60 TosaQuantGen.getZeroPoint(testGen, dtype),
61 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
62 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010063 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000064 qinfo = [
65 TosaQuantGen.getZeroPoint(testGen, dtype),
66 TosaQuantGen.getZeroPoint(testGen, dtype),
67 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010068 return qinfo
69
70 @staticmethod
71 def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010072 if isinstance(dtype_or_dtypeList, list):
73 # a list of [input, weights, accumulator] dtypes
74 dtypeList = dtype_or_dtypeList
75 else:
76 # an int, [input, weights, accumulator] dtypes are the same
77 dtypeList = [dtype_or_dtypeList] * 3
78
79 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000080 qinfo = [
81 TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
82 TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
83 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010084 elif error_name == ErrorIf.WeightZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000085 qinfo = [
86 TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
87 TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
88 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010089 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000090 qinfo = [
91 TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
92 TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
93 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010094 return qinfo
95
96 @staticmethod
97 def qgMatmul(testGen, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010098 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000099 qinfo = [
100 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
101 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
102 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100103 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000104 qinfo = [
105 TosaQuantGen.getZeroPoint(testGen, dtype),
106 TosaQuantGen.getZeroPoint(testGen, dtype),
107 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100108 return qinfo
109
110 @staticmethod
111 def computeMultiplierAndShift(scaleFp, scale32):
112 # Derived from computeMultiplierAndShiftTosaScale32
113 # Provide a floating-point scaling factor and the scale32 parameter
114 # to compute the multiplier and shift
115
116 if scale32:
117 scaleBits = 31
118 else:
119 scaleBits = 15
120
121 m, shift = math.frexp(scaleFp)
122
123 if scaleFp < 0.0:
124 m = -m
125
126 multiplier = round(m * (1 << scaleBits))
127 assert multiplier <= (1 << scaleBits)
128
129 if multiplier == (1 << scaleBits):
130 multiplier = multiplier // 2
131 shift = shift + 1
132
133 shift = (-shift) + scaleBits
134 # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
135 # scaleFp, scaleBits, m, multiplier, shift))
136
137 # Adjust multiplier such that shift is in allowed value range.
138 if shift == 0:
139 multiplier = multiplier // 4
140 shift = shift + 2
141 elif shift == 1:
142 multiplier = multiplier // 2
143 shift = shift + 1
144 elif shift == 63:
145 multiplier = multiplier * 2
146 shift = shift - 1
147
148 assert multiplier <= (1 << scaleBits)
149 assert shift >= 2 and shift <= 62
150
151 return multiplier, shift
152
153
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        # Stateless - all generators are static methods
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Return one shape of the given rank per operand (all identical)."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            # NOTE: mutating `shape` here changes what later iterations append,
            # so inputs other than index 1 end up with a different rank
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Return identical rank-4 (NHWC) shapes with a constricted batch size."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgGather(testGen, opName, rank, error_name=None):
        """Return [values (N,K,C), indices (N,W)] shapes for GATHER."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_shape = testGen.makeShape(rank)
        values_shape = testGen.constrictBatchSize(values_shape)

        # Indices share the batch dimension with values
        N = values_shape[0]
        W = testGen.makeDimension()
        indices_shape = [N, W]

        shape_list = [values_shape, indices_shape]
        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Return [values_in (N,K,C), indices (N,W), input (N,W,C)] for SCATTER."""
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = testGen.randInt(W_min, W_max) if W_min < W_max else W_min

        input_shape = [N, W, C]

        shape_list = []
        shape_list.append(values_in_shape)
        shape_list.append([N, W])  # indices
        shape_list.append(input_shape)

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Return shapes where one randomly chosen input is broadcastable
        (one dimension set to 1), or deliberately broken for ERROR_IF tests."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        fuzz_idx = testGen.randInt(0, rank)

        for i in range(pl + const):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    # Dimensions that are neither equal nor 1 cannot broadcast
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (OHWI), bias (OC)] shapes for CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Return [ifm (NDHWC), filter (ODHWI), bias (OC)] shapes for CONV3D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension()

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (OHWI), bias (OC)] shapes for TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (HWCM), bias (C*M)] shapes for DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        """Return two matching NHW shapes (real & imaginary inputs) for FFT2D.
        H and W are forced to powers of two unless testing KernelNotPowerOfTwo."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # Increment by 2 if current size is 1 (1+1=2 is a power of two)
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        """Return a single NHW shape for RFFT2D with power-of-two H and W
        (unless testing KernelNotPowerOfTwo)."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Return [input (N,IC), filter (OC,IC), bias (OC)] shapes for FULLY_CONNECTED."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        # Random output channels within the configured shape range
        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Return [a (N,H,C), b (N,C,W)] shapes for MATMUL."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Return a variable number of identical shapes for CONCAT, optionally
        corrupting ranks for ConcatInputRankMismatch tests."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                # Randomly grow or shrink the rank of every input after the first
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Post-process CONCAT shapes: split the first shape along axis so
        multiple const inputs concatenate back to the original extent."""
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
624
625
626class TosaTensorValuesGen:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100627 """Tensor Value generators create the random data for each tensor in each test."""
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100628
    def __init__(self):
        # Stateless - all value generators are static methods
        pass
631
    class TVGInfo:
        """Enhanced tensor values information including data gen dict.

        tensorList: serializer tensor objects created for the test.
        dataGenDict: data generation meta-data (JSON-style dict), or None
        when the data was generated internally.
        """

        def __init__(self, tensorList, dataGenDict):
            self.tensorList = tensorList
            self.dataGenDict = dataGenDict
638
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100639 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000640 def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100641 pCount, cCount = op["operands"]
642
643 tens = []
644 tens.extend(
645 testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
646 )
647 tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
648
649 return tens
650
    # Default high value for random numbers
    # Each entry is the largest finite value representable in the format:
    # e.g. FP32 max = (2 - 2^-23) * 2^127 = 2^128 - 2^104
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
    }

    # Default lowest normal values for random numbers
    # Each entry is the smallest positive normal value of the format
    # (BF16 shares FP32's exponent range)
    TVG_FLOAT_LOW_VALUE = {
        DType.FP32: np.exp2(-126),
        DType.FP16: np.exp2(-14),
        DType.BF16: np.exp2(-126),
    }
664
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100665 @staticmethod
Jeremy Johnson30476252023-11-20 16:15:30 +0000666 def _get_data_range(testGen, dtype, highValueLookup, lowValueLookup=None):
667 # Return a tuple of (low,high) data range values for the given data
668 # type using a combination of per operator table limits, data limits
669 # and user supplied ranges for FP numbers
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000670 if dtype in highValueLookup:
Jeremy Johnson30476252023-11-20 16:15:30 +0000671 type_range = testGen.getDTypeRange(dtype, high_inclusive=True)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000672 high_val = highValueLookup[dtype]
Jeremy Johnson30476252023-11-20 16:15:30 +0000673 if lowValueLookup is not None and dtype in lowValueLookup:
674 low_val = lowValueLookup[dtype]
675 else:
676 low_val = -high_val
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000677 # Set the values to something that won't produce infinity whilst
Jeremy Johnson30476252023-11-20 16:15:30 +0000678 # respecting the default ranges if more/less than the low/high
679 # values
680 data_range = (
681 max(low_val, type_range[0]),
682 min(high_val, type_range[1]),
683 )
684 if data_range[0] > data_range[1]:
685 # Invalid data range from low to high created due to user
686 # constraints revert to using internal ranges as they are
687 # known to work
688 msg = f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
689 warnings.warn(msg)
690 data_range = (low_val, high_val)
691 return data_range
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000692 return None
693
    @staticmethod
    def tvgLazyGenDefault(
        testGen, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Create test tensors plus (optionally lazy) data generation meta-data.

        Returns a TVGInfo.  For ERROR_IF tests, unsupported dtypes or ops
        without "data_gen" config, data is created immediately in-process and
        TVGInfo.dataGenDict is None.  Otherwise a JSON-style meta-data dict is
        built per tensor; data is generated now unless --lazy_data_gen is set
        (the DOT_PRODUCT bias tensor is always generated so compliance can
        check for non-zero bias).  May set argsDict["ksb"] as a side effect.
        """
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
        if "p_count" in argsDict:
            # Override for operators like CONCAT
            pCount = argsDict["p_count"]
            cCount = argsDict["c_count"]
        assert pCount + cCount == len(
            shapeList
        ), "Placeholders & Constant tensors must match shapes list"

        tens_ser_list = []

        if (
            error_name is not None
            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
            or "data_gen" not in testGen.TOSA_OP_LIST[opName]
        ):
            # Fall back to internal data gen when dealing with unsupported types or ops
            data_range = argsDict["data_range"] if "data_range" in argsDict else None
            for idx, info in enumerate(zip(shapeList, dtypeList)):
                roundMode = False
                shape, dtype = info
                if "data_range_list" in argsDict:
                    # Per-tensor range overrides the single shared range
                    data_range = argsDict["data_range_list"][idx]["range"]
                    roundMode = (
                        "round" in argsDict["data_range_list"][idx]
                        and argsDict["data_range_list"][idx]["round"] is True
                    )
                if data_range is not None and dtype not in (
                    DType.FP16,
                    DType.FP32,
                    DType.BF16,
                ):
                    # Change from inclusive to exclusive range
                    data_range = (data_range[0], data_range[1] + 1)
                # Ignore lazy data gen option and create data array using any range limits

                if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                    arr = np.int64(argsDict["fixed_data"][idx])
                else:
                    arr = testGen.getRandTensor(shape, dtype, data_range)
                if roundMode:
                    arr = np.round(arr)
                if idx < pCount:
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))

            # No data gen meta-data for internally generated data
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
        dg_type = argsDict["dg_type"]
        tens_data = {
            "version": "0.1",
            "tensors": {},
        }
        dg_tens_meta = tens_data["tensors"]
        for idx, shape in enumerate(shapeList):

            tens_meta = {}
            # Fixed data (if supplied for this position) overrides the chosen
            # generator type
            if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                tens_meta["generator"] = gtu.DataGenType(
                    gtu.DataGenType.FIXED_DATA
                ).name
            else:
                tens_meta["generator"] = gtu.DataGenType(dg_type).name

            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()

            if idx < pCount:
                tens_meta["input_type"] = "VARIABLE"
            else:
                tens_meta["input_type"] = "CONSTANT"

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                if (
                    tens_meta["generator"]
                    == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name
                ):
                    info["data"] = [int(i) for i in argsDict["fixed_data"][idx]]
                    tens_meta["fixed_data_info"] = info
                else:
                    # TODO - generate seed for this generator based on test
                    info["rng_seed"] = 42

                    # Range priority: per-tensor list, then shared range,
                    # then the dtype's default range
                    data_range = None
                    if "data_range_list" in argsDict:
                        data_range = argsDict["data_range_list"][idx]["range"]
                        if "round" in argsDict["data_range_list"][idx]:
                            info["round"] = argsDict["data_range_list"][idx]["round"]
                    elif "data_range" in argsDict:
                        data_range = argsDict["data_range"]

                    if data_range is None:
                        data_range = testGen.getDTypeRange(
                            dtypeList[idx], high_inclusive=True
                        )
                    info["range"] = [str(v) for v in data_range]
                    tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = int(argsDict["ks"])
                if "acc_type" in argsDict:
                    # Convert type number into JSON name
                    info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
                        "json"
                    ]
                if "kernel" in argsDict:
                    info["kernel"] = [int(k) for k in argsDict["kernel"]]
                if "axis" in argsDict:
                    info["axis"] = int(argsDict["axis"])
                tens_meta["dot_product_info"] = info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"

            # Using the finished generate config meta data - generate the data if
            # needed and assign a tensor name from the serializer

            # Need to generate data when not lazy or for the bias tensor as we need
            # to work out if the bias data is non-zero for compliance
            if not testGen.args.lazy_data_gen or (
                idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
            ):
                # Give this tensor a temporary name until we get one from the serializer
                temp_name = f"placeholder_{idx}"
                dg_tens_meta[temp_name] = tens_meta
                # Create data now using the temporary name to access meta details
                data = testGen.dgl.get_tensor_data(temp_name, tens_data)
                if tens_meta["data_type"] == "SHAPE":
                    # Tensor type SHAPE and Numpy file type must be the same
                    data = np.int64(data)
                # Remove the item as we will give it the correct name later
                del dg_tens_meta[temp_name]

            if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
                # The KS value used by compliance verification is altered when the
                # bias data is non-zero
                if max(abs(data)) > 0.0:
                    argsDict["ksb"] = argsDict["ks"] + 1

            if testGen.args.lazy_data_gen:
                # Data will be generated by the data generator library later
                data = None

            if tens_meta["input_type"] == "VARIABLE":
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
            else:
                tens = testGen.ser.addConst(shape, dtypeList[idx], data)

            tens_ser_list.append(tens)
            # Add the meta data to the list using the serializer tensor name
            dg_tens_meta[tens.name] = tens_meta

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
857
858 @staticmethod
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000859 def tvgNegate(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100860 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000861 # Integer test
862 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100863 pCount, cCount = op["operands"]
864 assert (
865 pCount == 1 and cCount == 0
866 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100867 # Must create tensors with values within accumulator (int32) negatable
868 # range
869 max_val = (1 << 31) - 1
870 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100871 arr = np.int32(
872 testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
873 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000874 tens_ser_list = []
875 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100876 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
877 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000878 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100879 else:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000880 # ERROR_IF or floating point test
881 return TosaTensorValuesGen.tvgLazyGenDefault(
882 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100883 )
884
    # Set the ADD/SUB data range to half the largest value to avoid infinities
    # (two operands each at the full-range maximum would sum to infinity)
    TVG_FLOAT_HIGH_VALUE_ADDSUB = {
        DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
        DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
        DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
    }
891
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100892 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000893 def tvgAddSub(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Won Jeon74342e52024-01-09 00:34:40 +0000894 if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000895 # Make sure the integer operation does not cause value saturation - where
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100896 # the number wraps due to limited number of bits to store the answer
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000897 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100898 pCount, cCount = op["operands"]
899 assert (
900 pCount == 2 and cCount == 0
901 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000902 tens_ser_list = []
Won Jeon74342e52024-01-09 00:34:40 +0000903 add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
904 data_range = testGen.args.tensor_shape_range
905 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0], data_range)
906 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1], data_range)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100907 if add:
908 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
909 else:
910 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
911
912 # Work out the saturation limits
913 max_i32 = (1 << 31) - 1
914 min_i32 = -(1 << 31)
915 max_arr = np.full(shapeList[1], max_i32)
916 min_arr = np.full(shapeList[1], min_i32)
917
918 # Find how much values exceed the maximum/minimums
919 sat_max_arr = np.maximum(res_arr - max_arr, 0)
920 sat_min_arr = np.minimum(res_arr - min_arr, 0)
921
922 if not add:
923 # Swap saturation values and negate values as we need to perform opposite operations
924 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
925
926 # Create new array of unsaturated values by clipping values as needed
927 b_unsat_arr = b_arr
928 if (sat_max_arr != 0).any():
929 # Clip values that cause saturation
930 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
931 # Reduce axes in unsaturated tensor to match original tensor
932 for axis, dim in enumerate(b_arr.shape):
933 if dim != b_unsat_arr.shape[axis]:
934 assert (
935 dim == 1
936 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
937 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
938
939 if (sat_min_arr != 0).any():
940 # Clip values that cause saturation
941 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
942 # Reduce axes in unsaturated tensor to match original tensor
943 for axis, dim in enumerate(b_arr.shape):
944 if dim != b_unsat_arr.shape[axis]:
945 assert (
946 dim == 1
947 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
948 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
949
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000950 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100951 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
952 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000953 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100954 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
955 )
956
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000957 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100958 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000959 # ERROR_IF or floating point test
960 data_range = TosaTensorValuesGen._get_data_range(
961 testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
962 )
963 if data_range:
964 argsDict["data_range"] = data_range
965
966 return TosaTensorValuesGen.tvgLazyGenDefault(
967 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100968 )
969
970 @staticmethod
971 def tvgCondIfWhileLoop(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000972 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100973 ):
974 if dtypeList[0] in (
975 DType.INT32,
976 DType.INT16,
977 DType.INT8,
978 ):
979 # Limit input tensors with cond_if_binary or while_loop to stop
980 # saturation of add/sub ops with int32 and keep all logical shift
981 # values between 0 to 31 for int16 or int8
982 pCount, cCount = op["operands"]
983 pRemain = pCount
984 placeholders = []
985 for idx, shape in enumerate(shapeList[:]):
986 if dtypeList[0] == DType.INT32:
987 arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
988 else:
989 arr = np.int32(
990 testGen.rng.integers(low=0, high=32, size=shapeList[idx])
991 )
992 if pRemain > 0:
993 placeholders.append(
994 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
995 )
996 pRemain -= 1
997 else:
998 placeholders.append(
999 testGen.ser.addConst(shape, dtypeList[idx], arr)
1000 )
1001
1002 return placeholders
1003 else:
1004 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001005 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001006 )
1007
1008 @staticmethod
1009 def tvgArithmeticRightShift(
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001010 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001011 ):
1012 pCount, cCount = op["operands"]
1013 # Force value of operand[1] to be within [0, num_bits]
1014 assert (
1015 pCount == 2 and cCount == 0
1016 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
1017
1018 placeholders = []
1019 for idx, shape in enumerate(shapeList[:]):
1020 if idx == 1:
1021 if dtypeList[idx] == DType.INT8:
1022 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
1023 elif dtypeList[idx] == DType.INT16:
1024 arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
1025 elif dtypeList[idx] == DType.INT32:
1026 arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
1027 elif error_name == ErrorIf.WrongInputType:
1028 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
1029 else:
1030 raise Exception("OpArithmeticRightShift: invalid input dtype")
1031 else:
1032 arr = testGen.getRandTensor(shape, dtypeList[idx])
1033 placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
1034
1035 return placeholders
1036
1037 @staticmethod
Won Jeon64e4bfe2024-01-18 06:31:55 +00001038 def tvgReshape(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
1039 dtypeList[1] = DType.SHAPE
1040 shapeList[1] = [len(argsDict["new_shape"])]
1041 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1042 argsDict["fixed_data"] = [None, argsDict["new_shape"]]
1043
1044 return TosaTensorValuesGen.tvgLazyGenDefault(
1045 testGen, op, dtypeList, shapeList, argsDict, error_name
1046 )
1047
1048 @staticmethod
Tai Lye095da72024-01-25 22:00:18 +00001049 def tvgPad(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
1050 # argsDict["pad"] is 2D array, need to flatten it to get list of values
1051 pad_values = argsDict["pad"].flatten()
1052 dtypeList[1] = DType.SHAPE
1053 shapeList[1] = [len(pad_values)]
1054 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1055 argsDict["fixed_data"] = [None, pad_values]
1056
1057 return TosaTensorValuesGen.tvgLazyGenDefault(
1058 testGen, op, dtypeList, shapeList, argsDict, error_name
1059 )
1060
1061 @staticmethod
TatWai Chongf15bad82024-01-31 21:33:27 -08001062 def tvgSlice(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
1063 dtypeList[1] = DType.SHAPE
1064 shapeList[1] = [len(argsDict["start"])]
1065 dtypeList[2] = DType.SHAPE
1066 shapeList[2] = [len(argsDict["size"])]
1067 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1068 argsDict["fixed_data"] = [None, argsDict["start"], argsDict["size"]]
1069
1070 return TosaTensorValuesGen.tvgLazyGenDefault(
1071 testGen, op, dtypeList, shapeList, argsDict, error_name
1072 )
1073
1074 @staticmethod
Won Jeon64e4bfe2024-01-18 06:31:55 +00001075 def tvgTile(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
1076 dtypeList[1] = DType.SHAPE
1077 shapeList[1] = [len(argsDict["multiples"])]
1078 argsDict["fixed_data"] = [None, argsDict["multiples"]]
1079
1080 return TosaTensorValuesGen.tvgLazyGenDefault(
1081 testGen, op, dtypeList, shapeList, argsDict, error_name
1082 )
1083
1084 @staticmethod
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00001085 def tvgSelect(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001086 # Set datatype of condition tensor to boolean
1087 dtypeList[0] = DType.BOOL
1088
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00001089 return TosaTensorValuesGen.tvgLazyGenDefault(
1090 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001091 )
1092
    @staticmethod
    def tvgIntDiv(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Generate INTDIV inputs that avoid the operator's undefined cases.

        Rejection-samples the dividend/divisor tensors until neither invalid
        case (zero divisor; min-int divided by -1) is present, then serializes
        both operands.  ERROR_IF tests fall back to the lazy generator.
        """
        if error_name is None:
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            tens_ser_list = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
            while True:
                dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                # NOTE(review): rejects tensors where min-int and -1 appear
                # anywhere, not only at matching positions - conservative but safe
                if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
                    continue

                break

            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
1131
    # Set the MUL data range to the square root of the largest value
    # to avoid infinities
    # (two operands each at sqrt(max) multiply to exactly the dtype maximum)
    TVG_FLOAT_HIGH_VALUE_MUL = {
        DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
1139
    @staticmethod
    def tvgMul(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Generate input values for MUL (and MUL_SHAPE).

        Floating point and ERROR_IF tests use the lazy generator with a
        reduced range; integer tests repeatedly halve the operands until the
        shifted product of every element fits in int32.
        """
        if error_name is not None or dtypeList[0] in (
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ):
            # ERROR_IF or floating point test
            data_range = TosaTensorValuesGen._get_data_range(
                testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
            )
            if data_range:
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
        else:
            # Integer test
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            tens_ser_list = []

            # Make sure multiply result in int32 range
            # MUL_SHAPE has no shift attribute
            if dtypeList[0] == DType.SHAPE:
                shift = 0
            else:
                shift = argsDict["shift"]
            if dtypeList[0] == DType.INT8:
                num_bits = 8
            elif dtypeList[0] == DType.INT16:
                num_bits = 16
            elif dtypeList[0] in (DType.INT32, DType.SHAPE):
                num_bits = 32
            elif error_name == ErrorIf.WrongInputType:
                num_bits = 8
            else:
                raise Exception("OpMul: invalid input dtype")

            # NOTE(review): this loop regenerates a_arr/b_arr once per input,
            # so the low/high of the LAST index is what both arrays end up
            # using - presumably both inputs always share a dtype; verify
            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[idx] == DType.SHAPE:
                    low = testGen.args.tensor_shape_range[0]
                    high = testGen.args.tensor_shape_range[1]
                else:
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                a_arr = np.int32(
                    testGen.rng.integers(low=low, high=high, size=shapeList[0])
                )
                b_arr = np.int32(
                    testGen.rng.integers(low=low, high=high, size=shapeList[1])
                )

            # Halve the operands until the (rounded, shifted) product of every
            # element fits within int32
            # NOTE(review): `i` only counts iterations and is otherwise unused
            i = 0
            while True:

                a_arr_64 = a_arr.astype(np.int64)
                b_arr_64 = b_arr.astype(np.int64)

                if shift > 0:
                    rounding = 1 << (shift - 1)
                    result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                else:
                    result_arr = a_arr_64 * b_arr_64

                if (result_arr > -(2**31)).all() and (
                    result_arr <= ((2**31) - 1)
                ).all():
                    break

                i = i + 1
                a_arr = a_arr // 2
                b_arr = b_arr // 2

            # SHAPE operands are serialized as 64-bit values
            if dtypeList[0] == DType.SHAPE:
                tens_ser_list.append(
                    testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
                )
                tens_ser_list.append(
                    testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
                )
            else:
                tens_ser_list.append(
                    testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                tens_ser_list.append(
                    testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001235
1236 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001237 def tvgConcat(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001238 count = len(shapeList) - testGen.args.num_const_inputs_concat
1239 if count < 1:
1240 count = 1
1241 if testGen.args.num_const_inputs_concat == 0:
1242 count = len(shapeList)
1243
Won Jeon74342e52024-01-09 00:34:40 +00001244 op = testGen.TOSA_OP_LIST[opName]
1245 if op["op"] == Op.CONCAT_SHAPE:
1246 # Set the axis to 0
1247 shapeList = TosaTensorGen.tgConcatConstInput(
1248 testGen, shapeList, 0, error_name
1249 )
1250 else:
1251 shapeList = TosaTensorGen.tgConcatConstInput(
1252 testGen, shapeList, argsDict["axis"], error_name
1253 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001254
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001255 # Override default pCount/cCount for operator
1256 argsDict["p_count"] = count
1257 argsDict["c_count"] = len(shapeList) - count
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001258
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001259 return TosaTensorValuesGen.tvgLazyGenDefault(
1260 testGen, opName, dtypeList, shapeList, argsDict, error_name
1261 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001262
1263 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001264 def tvgLogicalShift(
1265 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1266 ):
1267 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001268 pCount, cCount = op["operands"]
1269 assert (
1270 pCount == 2 and cCount == 0
1271 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
1272 values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1273 shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001274 tens_ser_list = []
1275 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001276 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1277 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001278 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001279 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1280 )
1281
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001282 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001283
    @staticmethod
    def tvgEqual(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Generate input values for EQUAL.

        For integer tests, a number of positions are deliberately forced to be
        equal across the two (possibly broadcast) inputs so the output is not
        all-false.  Floating point and ERROR_IF tests use the lazy generator.
        """
        if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
            # Integer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"

            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])

            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        testGen.rng.integers(
                            0, max(shapeList[0][axis], shapeList[1][axis])
                        )
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                # Copy the chosen element so the two inputs compare equal there
                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
1330
    @staticmethod
    def tvgReduceSum(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Generate input values for REDUCE_SUM.

        Values are bounded so that summing along any single axis cannot
        overflow int32 (integer tests) or reach infinity (float tests).
        """
        dtype = dtypeList[0]
        if dtype == DType.INT32:
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
            # Limit values so that the sum cannot exceed the range of an int32 during
            # summation of any axis
            range_val = int((1 << 31) / max(shapeList[0]))
            values_arr = np.int32(
                testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
            )
            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or dot product floating point test
            if (
                error_name is None
                and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
            ):
                # Limit ranges for (non error & non compliance) tests by using
                # values that can be summed on any axis to not hit infinity
                highval_lookup = {
                    dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
                    / max(shapeList[0])
                }
                data_range = TosaTensorValuesGen._get_data_range(
                    testGen, dtype, highval_lookup
                )
                assert data_range is not None
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
1372
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001373 @staticmethod
1374 def tvgReduceProduct(
1375 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1376 ):
1377 dtype = dtypeList[0]
1378 if error_name is None:
1379 # Limit ranges for (non error) tests by using
1380 # values that can be multiplied on any axis to not hit infinity
1381 highval_lookup = {
1382 dtype: math.pow(
1383 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
1384 1 / max(shapeList[0]),
1385 )
1386 }
1387 data_range = TosaTensorValuesGen._get_data_range(
1388 testGen, dtype, highval_lookup
1389 )
1390 assert data_range is not None
1391 argsDict["data_range"] = data_range
1392
1393 return TosaTensorValuesGen.tvgLazyGenDefault(
1394 testGen, opName, dtypeList, shapeList, argsDict, error_name
1395 )
1396
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00001397 @staticmethod
1398 def tvgResize(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1399 data_range = TosaTensorValuesGen._get_data_range(
1400 testGen,
1401 dtypeList[0],
1402 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1403 )
1404 if data_range:
1405 argsDict["data_range"] = data_range
1406 # Needed for compliance
1407 argsDict["max_abs_value"] = data_range[1]
1408
1409 return TosaTensorValuesGen.tvgLazyGenDefault(
1410 testGen, opName, dtypeList, shapeList, argsDict, error_name
1411 )
1412
    # Set the POW exponent high data range
    TVG_FLOAT_HIGH_VALUE_POW_EXP = {
        DType.FP32: 10.0,
        DType.FP16: 10.0,
        DType.BF16: 10.0,
    }
    # POW highest base value (within a safe margin of error) that can be raised
    # to +ve exponent that doesn't become Infinity
    # (i.e. floor(max ** (1/max_exponent)) per dtype)
    TVG_FLOAT_HIGH_VALUE_POW_BASE = {
        DType.FP32: math.floor(
            math.pow(
                TVG_FLOAT_HIGH_VALUE[DType.FP32],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
            )
        ),
        DType.FP16: math.floor(
            math.pow(
                TVG_FLOAT_HIGH_VALUE[DType.FP16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
            )
        ),
        DType.BF16: math.floor(
            math.pow(
                TVG_FLOAT_HIGH_VALUE[DType.BF16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
            )
        ),
    }
    # POW lowest base value (within a safe margin of error) that can be raised
    # to -ve exponent that doesn't become Infinity
    # (rounded up to 3 decimal places to keep a safety margin)
    TVG_FLOAT_LOW_VALUE_POW_BASE = {
        DType.FP32: math.ceil(
            math.pow(
                1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
            )
            * 1000
        )
        / 1000,
        DType.FP16: math.ceil(
            math.pow(
                1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
            )
            * 1000
        )
        / 1000,
        DType.BF16: math.ceil(
            math.pow(
                1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
            )
            * 1000
        )
        / 1000,
    }
1469
    @staticmethod
    def tvgPow(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Generate input values for POW.

        Three test sets (argsDict["s"]) choose different base/exponent ranges:
        0 - positive base with fractional exponent, 1 - positive base with
        integer exponent, 2 - negative base with integer exponent.  Ranges are
        chosen so the result cannot overflow to infinity.
        """
        if error_name is not None:
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
        dtype = dtypeList[0]
        # Different ranges for POW
        test_set = argsDict["s"]
        if test_set == 0:
            # Positive base with fractional exponent
            base_range = TosaTensorValuesGen._get_data_range(
                testGen,
                dtype,
                TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
                TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
            )
            exp_range = TosaTensorValuesGen._get_data_range(
                testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
            )
            exp_round = False
        else:
            # Integer exponent
            exp_range = TosaTensorValuesGen._get_data_range(
                testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
            )
            exp_round = True
            if test_set == 1:
                # Positive base
                base_range = TosaTensorValuesGen._get_data_range(
                    testGen,
                    dtype,
                    TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
                    TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
                )
            else:
                assert test_set == 2
                # Negative base
                # Supply new look up tables with negative values
                base_range = TosaTensorValuesGen._get_data_range(
                    testGen,
                    dtype,
                    {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
                    {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
                )

        # Per-input ranges: entry 0 is the base tensor, entry 1 the exponent
        data_range_list = (
            {
                "range": base_range,
            },
            {
                "range": exp_range,
                "round": exp_round,
            },
        )
        argsDict["data_range_list"] = data_range_list
        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, opName, dtypeList, shapeList, argsDict, error_name
        )
1529
1530 @staticmethod
1531 def tvgLogRsqrt(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1532 # LOG & RSQRT data range from lowest expressible positive number to
1533 # largest to avoid NaNs
1534 data_range = TosaTensorValuesGen._get_data_range(
1535 testGen,
1536 dtypeList[0],
1537 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1538 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1539 )
1540 if data_range:
1541 argsDict["data_range"] = data_range
1542
1543 return TosaTensorValuesGen.tvgLazyGenDefault(
1544 testGen, opName, dtypeList, shapeList, argsDict, error_name
1545 )
1546
    # Set the EXP data range to the log of the largest to smallest values
    # to avoid infinities or making the result zero
    TVG_FLOAT_HIGH_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
    # Lower bound: exp() of anything below this flushes towards zero
    TVG_FLOAT_LOW_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
    }
1559
1560 @staticmethod
1561 def tvgExp(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1562 data_range = TosaTensorValuesGen._get_data_range(
1563 testGen,
1564 dtypeList[0],
1565 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1566 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1567 )
1568 if data_range:
1569 argsDict["data_range"] = data_range
1570
1571 return TosaTensorValuesGen.tvgLazyGenDefault(
1572 testGen, opName, dtypeList, shapeList, argsDict, error_name
1573 )
1574
    @staticmethod
    def tvgFullyConnected(
        testGen, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Generate input values for FULLY_CONNECTED.

        For BF16 non-compliance tests the range is limited so accumulating
        over the input channels cannot hit infinity/NaN; everything else uses
        the default lazy generator unchanged.
        """
        dtype = dtypeList[0]
        if (
            error_name is None
            and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
            and dtype in (DType.BF16,)
        ):
            # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
            # Limit ranges for (non error & non compliance) FP tests by using
            # values that can be multiplied on any axis to not hit infinity/NaN
            # IC = input channels dimension of the [N, IC] input shape
            IC = shapeList[0][1]
            highval_lookup = {
                dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
            }
            data_range = TosaTensorValuesGen._get_data_range(
                testGen, dtype, highval_lookup
            )
            assert data_range is not None
            argsDict["data_range"] = data_range

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, opName, dtypeList, shapeList, argsDict, error_name
        )
1601
    @staticmethod
    def tvgCast(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Limit CAST inputs by the output type's maximum, then default-generate.

        Bounding input magnitudes by the destination type's range avoids FP
        infinities and saturation of integer results after the cast.
        """
        in_dtype = dtypeList[0]
        out_dtype = argsDict["out_type"]
        # Create look up to limit input tensor to output type maximums to avoid
        # FP infinities and saturation of integers
        out_range = testGen.getDTypeRange(out_dtype, high_inclusive=True)
        highval_lookup = {in_dtype: out_range[1]}
        data_range = TosaTensorValuesGen._get_data_range(
            testGen,
            in_dtype,
            highval_lookup,
        )

        assert data_range is not None
        argsDict["data_range"] = data_range

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, opName, dtypeList, shapeList, argsDict, error_name
        )
1622
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001623 @staticmethod
1624 def tvgGather(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1625 K = shapeList[0][1]
1626
1627 # Fix the type of the indices tensor
1628 dtypeList[1] = DType.INT32
1629
1630 dtype = dtypeList[0]
1631 if not gtu.dtypeIsSupportedByCompliance(dtype):
1632 # Test unsupported by data generator
1633 op = testGen.TOSA_OP_LIST[opName]
1634 pCount, cCount = op["operands"]
1635 assert (
1636 pCount == 2 and cCount == 0
1637 ), "Op.GATHER must have 2 placeholders, 0 consts"
1638
1639 tens_ser_list = []
1640 for idx, shape in enumerate(shapeList):
1641 dtype = dtypeList[idx]
1642 if idx != 1:
1643 arr = testGen.getRandTensor(shape, dtype)
1644 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1645 else:
1646 # Limit data range of indices tensor upto K (exclusive)
1647 arr = testGen.getRandTensor(shape, dtype, (0, K))
1648 # To match old functionality - create indices as CONST
1649 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1650
1651 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1652
1653 else:
1654 # ERROR_IF or floating point test
1655 # Use inclusive values upto index K for indices tensor
1656 data_range_list = (
1657 {"range": None},
1658 {"range": (0, K - 1)},
1659 )
1660 argsDict["data_range_list"] = data_range_list
1661
1662 return TosaTensorValuesGen.tvgLazyGenDefault(
1663 testGen, opName, dtypeList, shapeList, argsDict, error_name
1664 )
1665
    @staticmethod
    def tvgScatter(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create SCATTER inputs with a valid, non-repeating indices tensor.

        Args:
            testGen: test generator (supplies RNG, serializer and op info)
            opName: operator name, used to look up operand counts
            dtypeList: tensor dtypes; entry 1 (indices) is forced to INT32
            shapeList: tensor shapes - values_in, indices and input tensors
            argsDict: argument dictionary, updated with data range info
            error_name: optional ERROR_IF test name

        Returns:
            TVGInfo with directly-built tensors, or the result of the lazy
            default generator.
        """
        # K - size of the scatterable dimension of values_in (N,K,C)
        K = shapeList[0][1]
        # W - number of indices per batch, from the input tensor (N,W,C)
        W = shapeList[2][1]

        # Work out an indices tensor here with data that doesn't exceed the
        # dimension K of the values_in tensor and does NOT repeat the same K
        # location as needed by the spec:
        # "It is not permitted to repeat the same output index within a single
        # SCATTER operation and so each output index occurs at most once."
        assert K >= W, "Op.SCATTER W must be smaller or equal to K"

        # Fix the type of the indices tensor
        dtypeList[1] = DType.INT32

        dtype = dtypeList[0]
        if not gtu.dtypeIsSupportedByCompliance(dtype):
            # Test unsupported by data generator
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 3 and cCount == 0
            ), "Op.SCATTER must have 3 placeholders, 0 consts"

            tens_ser_list = []
            for idx, shape in enumerate(shapeList):
                dtype = dtypeList[idx]
                if idx != 1:
                    arr = testGen.getRandTensor(shape, dtype)
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    # Create the indices array
                    assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
                    arr = []
                    for n in range(shape[0]):
                        # Get a shuffled list of output indices (0 to K-1) and
                        # limit length to W - guarantees uniqueness per batch
                        arr.append(testGen.rng.permutation(K)[:W])
                    indices_arr = np.array(arr, dtype=np.int32)  # (N, W)
                    # To match old functionality - create indices as CONST
                    tens_ser_list.append(
                        testGen.ser.addConst(shape, dtype, indices_arr)
                    )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        else:
            # ERROR_IF or floating point test
            # Use inclusive values upto index K for indices tensor
            data_range_list = (
                {"range": None},
                {"range": (0, K - 1)},
                {"range": None},
            )
            argsDict["data_range_list"] = data_range_list

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
1725
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001726
class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for
    operators that take attributes or other parameters.

    The return value is a list of (descriptive_name, [arglist]) tuples where
    the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function.
    """

    def __init__(self):
        # Stateless - all generators are static methods
        pass
1738
1739 @staticmethod
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001740 def _add_data_generators(testGen, opName, dtype, arg_list, error_name):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001741 """Add extra tests for each type of data generator for this op."""
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001742 if (
1743 error_name is None
1744 and "data_gen" in testGen.TOSA_OP_LIST[opName]
1745 and gtu.dtypeIsSupportedByCompliance(dtype)
1746 ):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001747 if dtype in [DType.FP16, DType.FP32, DType.BF16]:
1748 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
1749 else:
1750 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
1751 else:
1752 # Error test or No data generator types listed - assume random
1753 dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
1754
1755 # Expand arg list with other data generator types
1756 new_arg_list = []
1757 for dg_type in dataGenTypesList:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001758 for arg_str, args_dict in arg_list:
1759 args_dict["dg_type"] = dg_type
Jeremy Johnson1271c442023-09-05 11:39:26 +01001760 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
Jeremy Johnson30476252023-11-20 16:15:30 +00001761 if error_name is None:
1762 num_test_sets = (
1763 args_dict["num_test_sets"]
1764 if "num_test_sets" in args_dict
1765 else 0
1766 )
1767 else:
1768 num_test_sets = 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01001769
1770 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
1771 # Extra tests for each dot product test set
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001772 dot_products = args_dict["dot_products"]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001773 if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001774 shape_info = (
1775 " ({})".format(testGen.shapeStr(args_dict["shape"]))
1776 if "shape" in args_dict
1777 else ""
1778 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01001779 print(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001780 f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
Jeremy Johnson1271c442023-09-05 11:39:26 +01001781 )
1782 continue
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001783 # KS and acc_type is required by all dot product generators
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001784 assert "ks" in args_dict
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001785 assert "acc_type" in args_dict
Jeremy Johnson1271c442023-09-05 11:39:26 +01001786
Jeremy Johnson30476252023-11-20 16:15:30 +00001787 num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS
1788
1789 if num_test_sets > 0:
1790 for s in range(0, num_test_sets):
1791 new_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001792 new_args_dict = args_dict.copy()
1793 new_args_dict["s"] = s
1794 new_arg_list.append((new_arg_str, new_args_dict))
Jeremy Johnson30476252023-11-20 16:15:30 +00001795 else:
1796 # Default is a single test
1797 new_arg_list.append((arg_str, args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001798
1799 return new_arg_list
1800
1801 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001802 def agNone(testGen, opName, shapeList, dtype, error_name=None):
1803 """A trivial argument generator for operators that don't take any
1804 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001805 arg_list = TosaArgGen._add_data_generators(
1806 testGen,
1807 opName,
1808 dtype,
1809 [("", {})],
1810 error_name,
1811 )
1812 # Return list of tuples: (arg_str, args_dict)
1813 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001814
1815 @staticmethod
Jeremy Johnson30476252023-11-20 16:15:30 +00001816 def agPow(testGen, opName, shapeList, dtype, error_name=None):
1817 """Pow operator needs different test sets to cover random numbers
1818 without creating NaNs or Infs"""
1819 arg_list = TosaArgGen._add_data_generators(
1820 testGen,
1821 opName,
1822 dtype,
1823 [("", {"num_test_sets": 3})],
1824 error_name,
1825 )
1826 # Return list of tuples: (arg_str, args_dict)
1827 return arg_list
1828
1829 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001830 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
1831 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001832 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001833 shape = shapeList[0]
1834
1835 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001836 # Set too small axis
1837 axes = [testGen.rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001838 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001839 # Set too large axis
1840 axes = [testGen.rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001841 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001842 # Create tests for each dimension
1843 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001844
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001845 opid = testGen.TOSA_OP_LIST[opName]["op"]
1846
1847 for a in axes:
1848 args_dict = {"axis": int(a)}
1849 if opid == Op.REDUCE_SUM:
1850 args_dict["dot_products"] = gtu.product(shape)
1851 args_dict["shape"] = shape
1852 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1853 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1854
1855 arg_list.append(("axis{}".format(a), args_dict))
1856
1857 arg_list = TosaArgGen._add_data_generators(
1858 testGen,
1859 opName,
1860 dtype,
1861 arg_list,
1862 error_name,
1863 )
1864 # Return list of tuples: (arg_str, args_dict)
1865 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001866
1867 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001868 def _calculate_sparsity(num_tests, sparsity_factor):
1869 sparsity = num_tests // sparsity_factor + 1
1870 # If there are only a small number of tests, just select them all
1871 if sparsity < 13:
1872 sparsity = 1
1873 # To get a variety of parameter combinations sparsity should not be a
1874 # multiple of 2, 3 or 5
1875 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1876 sparsity += 1
1877 return sparsity
1878
    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        """Build stride/padding/dilation test arguments for convolution ops.

        Used by CONV2D, CONV3D and DEPTHWISE_CONV2D.

        Args:
            testGen: test generator (supplies args, RNG and 8k level limits)
            opName: operator name - selects conv3d/depthwise specific handling
            shapeList: [ifm_shape, filter_shape, ...]
            dtypes: tensor dtypes, used to derive the accumulator type
            error_name: optional ERROR_IF test name - these use specific
                invalid stride/padding/dilation values

        Returns:
            List of (arg_str, args_dict) tuples including compliance info.
        """
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        # Shape: Batches, (Depth), Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: (OFM channels), (KD), KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Op type checks
        conv3d = opName.startswith("conv3d")
        depthwise = opName.startswith("depthwise")

        # Check the rank
        rank = 5 if conv3d else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits channels
        k_rank = rank - 2
        # depthwise filters are (KH, KW, C, M) so the kernel starts at index 0
        k_pos = 0 if depthwise else 1
        k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
        # compliance size - KS
        k_size = gtu.product(k_shape)
        if not depthwise:
            k_size *= ifm_shape[-1]

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            if error_name == ErrorIf.PadSmallerZero:
                p_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                # Stride must be greater than 1 to force non-integer error
                startStride = (
                    1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
                )
                s_vals = [
                    x for x in range(startStride, testGen.args.max_conv_stride + 1)
                ]
            strides = {x for x in itertools.product(*([s_vals] * k_rank))}
            if error_name == ErrorIf.DilationSmallerOne:
                d_vals = [testGen.rng.choice(range(-5, 1))]
            else:
                d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
            dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[0, bigPadding]] * (k_rank * 2))
                            )
                        }
                    )
                bigStride = 8
                strides.update(
                    {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
                )
                bigDilation = 7
                dilations.update(
                    {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
                )
            # No output size limit outside of the 8k level tests
            max_dim_size = None

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(paddings) * len(strides) * len(dilations), sparsity_factor
            )
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            dilation_shape = [1] * k_rank
            pad_shape = [0] * k_rank * 2
            if conv3d:
                # Small stride apart from for big kernel (see below) to keep
                # tensor size/calculation small
                stride_shape = [1] * k_rank
                for idx in range(k_rank):
                    pad_offset = idx * 2
                    if k_shape[idx] == bigKernel:
                        # Padding shape needs to account for tensor shape
                        pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                        pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                        # Big stride to reduce output size
                        stride_shape[idx] = bigKernel
                    else:
                        # Account for kernel size
                        pad_shape[pad_offset] = k_shape[idx] - 1
            else:
                # Always have a large stride with extra padding and dilation to keep
                # tensor calculation reasonable
                stride_shape = [bigKernel] * k_rank
                for idx in range(k_rank):
                    # Dilation shape must account for kernel size
                    dilation_shape[idx] = bigKernel // k_shape[idx]
                    # Padding shape needs to accommodate tensor/kernel & dilation
                    pad_offset = idx * 2
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

            strides = {tuple(stride_shape)}
            dilations = {tuple(dilation_shape)}
            paddings = {tuple(pad_shape)}
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        # n counts every stride/pad/dilation combination so the sparsity
        # stride samples across the whole space
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a positive
                        # sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
                        and (
                            k_rank < 3
                            or (
                                (ifm_shape[3] - 1 + p[4] + p[5])
                                > d[2] * (k_shape[2] - 1)
                            )
                        )
                    ):
                        remainders = []
                        outputs = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            partial = (
                                ifm_shape[index + 1]
                                - 1
                                + p[pad_offset]
                                + p[pad_offset + 1]
                                - (k_shape[index] - 1) * d[index]
                            )
                            remainders.append(partial % s[index])
                            outputs.append((partial // s[index]) + 1)

                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            if (
                                max_dim_size is not None
                                and max(outputs) >= max_dim_size
                            ):
                                # Test will consume too much memory - skip it
                                continue

                            # Compliance - number of dot product calculations
                            if depthwise:
                                # N*OH*OW*C*M
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, *filter_shape[2:])
                                )
                            else:
                                # N*OH*OW*OC or N*OD*OH*OW*OC
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, filter_shape[0])
                                )
                            args_dict = {
                                "acc_type": accum_dtype,
                                "stride": s,
                                "pad": p,
                                "dilation": d,
                                "kernel": k_shape,
                                "ks": k_size,
                                "dot_products": dots,
                                "shape": ifm_shape,
                            }

                            # Support for larger values than 9 needs different delimiter
                            delim = "" if max(s + p + d) <= 9 else "x"
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        delim.join([str(x) for x in s]),
                                        delim.join([str(x) for x in p]),
                                        delim.join([str(x) for x in d]),
                                    ),
                                    args_dict,
                                )
                            )
                    n += 1

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtypes[0],
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2105
2106 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01002107 def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
2108
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002109 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002110 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01002111
2112 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson1271c442023-09-05 11:39:26 +01002113 accum_dtype = gtu.get_wrong_output_type(opName, testGen.rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01002114 elif error_name == ErrorIf.WrongInputType:
2115 # Pick some potentially correct output dtype if input type is incorrect
2116 accum_dtype = DType.INT32
2117 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +01002118 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01002119
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002120 # Set up compliance info
2121 args_dict = {
2122 "acc_type": accum_dtype,
2123 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
2124 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
2125 "shape": shapeList[0],
2126 }
2127
2128 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
2129
2130 arg_list = TosaArgGen._add_data_generators(
2131 testGen,
2132 opName,
2133 input_dtype,
2134 arg_list,
2135 error_name,
2136 )
2137 # Return list of tuples: (arg_str, args_dict)
2138 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002139
2140 @staticmethod
2141 def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
2142 # Get valid accumulate type(s)
2143 if dtype == DType.INT8:
2144 accum_dtypes = [DType.INT32]
2145 elif dtype == DType.INT16:
2146 accum_dtypes = [DType.INT48]
2147 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002148 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002149 elif dtype == DType.BF16:
2150 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002151 elif dtype == DType.FP32:
2152 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01002153 elif error_name is None:
2154 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
2155
2156 if error_name == ErrorIf.WrongOutputType:
2157 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson1271c442023-09-05 11:39:26 +01002158 accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01002159 elif error_name == ErrorIf.WrongInputType:
2160 # Pick some potentially correct output dtype if input type is incorrect
2161 accum_dtypes = [DType.INT32]
2162
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002163 # Set up compliance info
2164 args_dict = {
2165 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
2166 # Set dot_products = N*H*W
2167 "dot_products": gtu.product(
2168 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
2169 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002170 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002171 }
2172
2173 # Create arg tuple of string and dict
2174 arg_list = []
2175 for a in accum_dtypes:
2176 d = args_dict.copy()
2177 d["acc_type"] = a
2178 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01002179
2180 arg_list = TosaArgGen._add_data_generators(
2181 testGen,
2182 opName,
2183 dtype,
2184 arg_list,
2185 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002186 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002187 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01002188 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002189
2190 @staticmethod
2191 def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002192 arg_list = []
2193
Jeremy Johnson0c716862023-04-13 17:18:19 +01002194 if testGen.args.level8k and error_name is not None:
2195 # Don't produce negative large tests
2196 return arg_list
2197
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002198 ifm_shape = shapeList[0]
2199 filter_shape = shapeList[1]
2200
Jeremy Johnson1271c442023-09-05 11:39:26 +01002201 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01002202
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002203 # Must be rank 4
2204 if error_name != ErrorIf.WrongRank:
2205 assert len(ifm_shape) == 4
2206 assert len(filter_shape) == 4
2207
Jeremy Johnson0c716862023-04-13 17:18:19 +01002208 k_shape = tuple(filter_shape[1:3])
Jeremy Johnson95a67102024-01-10 14:16:39 +00002209 # compliance size - KS
2210 k_size = gtu.product((*k_shape, ifm_shape[3]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002211
Jeremy Johnson0c716862023-04-13 17:18:19 +01002212 if not testGen.args.level8k:
2213 # Generate comprehensive argument lists
2214 # - except for named errors, which use specific invalid value(s)
2215 smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
2216 if error_name == ErrorIf.PadLargerEqualKernel:
2217 max_filter_size = -max(k_shape[0], k_shape[1])
2218 p_vals = [
2219 testGen.rng.choice(range(max_filter_size - 10, max_filter_size))
2220 ]
2221 else:
2222 p_vals = [
2223 x
2224 for x in range(
2225 smallest_padding_size, testGen.args.max_conv_padding + 1
2226 )
2227 ]
2228 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2229 if error_name == ErrorIf.StrideSmallerOne:
2230 # Can't use stride=0, as it is used to derive output shape, as a divisor
2231 s_vals = [testGen.rng.choice(range(-5, 0))]
2232 else:
2233 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
2234 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002235
Jeremy Johnson0c716862023-04-13 17:18:19 +01002236 if not error_name and testGen.args.oversize:
2237 # add some oversize argument values
2238 if max(ifm_shape) < 64:
2239 bigPadding = 9
2240 paddings.update(
2241 {
2242 x
2243 for x in itertools.product(
2244 *([[smallest_padding_size, bigPadding]] * 4)
2245 )
2246 }
2247 )
2248 bigStride = 8
2249 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
2250
2251 # There are too many parameter combinations, so generate them sparsely,
2252 # very sparse for negative tests
2253 sparsity_factor = 2 if error_name else 10
2254 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
2255 # If there are only a small number of tests, just select them all
2256 if sparsity < 13:
2257 sparsity = 1
2258 # To get a variety of parameter combinations sparsity should not be a
2259 # multiple of 2, 3 or 5
2260 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
2261 sparsity += 1
2262 else:
2263 # Only test 8k levels boundaries
2264 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2265 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2266 bigPadding = bigKernel
2267
2268 pad_shape = [0] * (len(k_shape) * 2)
2269 stride_shape = [1] * len(k_shape)
2270 # The point at which input dimension combined with the stride will
2271 # create large output sizes!
2272 LARGE_SIZE = 2
2273 for idx in range(len(k_shape)):
2274 pad_offset = idx * 2
2275 if k_shape[idx] == bigKernel:
2276 # Set large stride
2277 stride_shape[idx] = bigKernel
2278 # Use negative output padding to reduce shape size
2279 pad_shape[pad_offset] = -(bigPadding - 1)
2280 if ifm_shape[idx + 1] > LARGE_SIZE:
2281 pad_shape[pad_offset + 1] = -(bigPadding - 1)
2282 else:
2283 # The other dimension should be the bigKernel
2284 alt_idx = 1 - idx
2285 if (
2286 k_shape[alt_idx] == bigKernel
2287 and ifm_shape[alt_idx + 1] < LARGE_SIZE
2288 ):
2289 # As the input is small, the large stride won't
2290 # affect the output so we can add some padding
2291 pad_shape[pad_offset + 1] = bigPadding
2292
2293 strides = {tuple(stride_shape)}
2294 paddings = {tuple(pad_shape)}
2295
2296 # Currently allow all combinations that are reasonable size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002297 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002298
2299 n = 0
2300 for s in sorted(list(strides)):
2301 for p in sorted(list(paddings)):
TatWai Chong24594f52022-06-08 00:48:04 -07002302 if n % sparsity == 0:
2303 # Determine the output shape
Jeremy Johnson0c716862023-04-13 17:18:19 +01002304 oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
2305 ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
TatWai Chong24594f52022-06-08 00:48:04 -07002306 os = [ifm_shape[0], oh, ow, filter_shape[0]]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002307
Jeremy Johnson95a67102024-01-10 14:16:39 +00002308 # N*OH*OW*OC
2309 dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
2310 args_dict = {
2311 "acc_type": accum_dtype,
2312 "stride": s,
2313 "pad": p,
2314 "kernel": k_shape,
2315 "ks": k_size,
2316 "dot_products": dots,
2317 "shape": ifm_shape,
2318 "out_shape": os,
2319 }
2320
Jeremy Johnson0c716862023-04-13 17:18:19 +01002321 # Support for larger values than 9 needs different delimiter
2322 delim = "" if max(s + p) <= 9 else "x"
TatWai Chong24594f52022-06-08 00:48:04 -07002323 arg_list.append(
2324 (
James Ward8b390432022-08-12 20:48:56 +01002325 "acc{}_st{}_pad{}_os{}".format(
2326 testGen.typeStr(accum_dtype),
Jeremy Johnson0c716862023-04-13 17:18:19 +01002327 delim.join([str(x) for x in s]),
2328 delim.join([str(x) for x in p]),
TatWai Chong24594f52022-06-08 00:48:04 -07002329 "x".join([str(x) for x in os]),
2330 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00002331 args_dict,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002332 )
TatWai Chong24594f52022-06-08 00:48:04 -07002333 )
2334 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002335
Jeremy Johnson95a67102024-01-10 14:16:39 +00002336 arg_list = TosaArgGen._add_data_generators(
2337 testGen,
2338 opName,
2339 dtypes[0],
2340 arg_list,
2341 error_name,
2342 )
2343 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002344 return arg_list
2345
    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        """Generate the padding argument combinations for the PAD operator.

        Returns a list of (arg_str, args_dict) tuples where args_dict holds
        the "pad" array plus the integer and floating point pad constants.
        """
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            # Use negative padding values to trigger the ERROR_IF case
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            # Unsupported dtype for PAD - no tests to create
            return []

        list_shape_pad_values = list(shape_pad_values)
        # If we are producing tests for rank 6 or greater use sparsity
        if len(list_shape_pad_values) > 1024:
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(list_shape_pad_values), sparsity_factor
            )
        else:
            sparsity = 1

        # Build arg list
        arg_list = []
        for n, paddings in enumerate(list_shape_pad_values):
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                    if all([p > -1 for p in paddings[i]]):
                        # No negative padding left on this axis - not an ERROR_IF case
                        args_valid = False
            if args_valid and n % sparsity == 0:
                # Encode the before/after padding of every axis into the test name
                name = "pad"
                for r in range(rank):
                    before, after = paddings[r]
                    name = f"{name}{before}{after}"
                args_dict = {
                    "pad": np.array(paddings),
                    "pad_const_int": pad_const_int,
                    "pad_const_fp": pad_const_fp,
                }
                arg_list.append((name, args_dict))

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2420
    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        """Generate stride/kernel/pad (and accumulator type) argument
        combinations for the pooling operators (avg_pool2d / max_pool2d).

        Returns a list of (arg_str, args_dict) tuples.
        """
        arg_list = []

        shape = shapeList[0]
        if error_name != ErrorIf.WrongRank:
            # Pooling input must be NHWC (rank 4)
            assert len(shape) == 4

        # Only produce the large "8k level" tests when requested and not an error test
        test_level8k = testGen.args.level8k and error_name is None

        startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
        startKernel = 2
        startPad = 0
        if not test_level8k:
            # Generate comprehensive argument lists
            p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * 4))}
            # Stride must be greater than 1 to force non-integer error
            s_vals = [
                x for x in range(startStride, testGen.args.max_pooling_stride + 1)
            ]
            strides = {x for x in itertools.product(*([s_vals] * 2))}
            k_vals = [
                x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
            ]
            kernels = {x for x in itertools.product(*([k_vals] * 2))}
            max_dim_size = None
        else:
            # Only test 8k levels
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            strides = {(1, bigStride), (bigStride, 4)}
            kernels = {(1, bigKernel), (bigKernel, 3)}
            paddings = set()
            for s in sorted(list(strides)):
                for k in sorted(list(kernels)):
                    padding = []
                    for idx in range(len(k)):
                        total_padding = s[idx] - shape[idx + 1] + k[idx]
                        while total_padding < 0:
                            # Must meet: shape + padding > kernel
                            total_padding += s[idx]
                        if total_padding < k[idx]:
                            padding.extend([0, total_padding])
                        else:
                            # Note this may produce padding >= k[idx] which is not
                            # allowed - but will be ignored in the creation loop below
                            padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
                    paddings.add(tuple(padding))
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

        # Choose the accumulator data types to test for this input dtype
        if opName == "max_pool2d":
            accum_dtypes = [None]  # max_pool has no accumulate dtype
        elif dtype == DType.INT8 or dtype == DType.INT16:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FP32]
        elif dtype == DType.BF16 or dtype == DType.FP32:
            accum_dtypes = [DType.FP32]
        elif error_name is None:
            assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
        else:
            # Set to something for the ErrorIf case which has
            # incorrect input data-type
            accum_dtypes = [DType.INT32]

        if not test_level8k:
            if testGen.args.oversize:
                # add some oversize argument values
                bigStride = 7
                bigKernel = 9
                strides.update(
                    {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
                )
                kernels.update(
                    {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
                )
                if max(shape) < 64:
                    # padding must be less than the kernel size
                    bigPadding = bigKernel - 1
                    paddings.update(
                        {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
                    )

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 500
            sparsity = (
                len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
            )
        else:
            # We have already limited test output combinations for 8k tests
            sparsity = 1

        # Test name template - the accumulator prefix is only used when one exists
        arg_str = (
            "acc{}_st{}_kern{}_pad{}"
            if accum_dtypes[0] is not None
            else "st{}_kern{}_pad{}"
        )

        # NOTE(review): mutable default shape=[] is shared across calls that omit it
        # - currently only read, never mutated, so it is harmless; confirm if changed
        def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
            # Return tuple containing the formatted argument string and
            # the corresponding argument values in a dictionary

            # Support for larger values than 9 needs different delimiter
            delim = "" if max(stride + kern + pad) <= 9 else "x"
            arg_str_elems = [
                delim.join([str(x) for x in stride]),
                delim.join([str(x) for x in kern]),
                delim.join([str(x) for x in pad]),
            ]
            args_dict = {
                "stride": stride,
                "pad": pad,
                "kernel": kern,
                "dot_products": dot_products,  # Ignored for error tests
                "shape": shape,
                "ks": gtu.product(kern),  # avg_pool2d: KS = KX*KY
            }

            if accum is not None:
                arg_str_elems.insert(0, testGen.typeStr(accum))
                args_dict["acc_type"] = accum
            return (arg_str.format(*arg_str_elems), args_dict)

        n = 0
        for a in accum_dtypes:
            for s in sorted(list(strides)):
                for p in sorted(list(paddings)):
                    for k in sorted(list(kernels)):
                        if error_name in [
                            ErrorIf.StrideSmallerOne,
                            ErrorIf.KernelSmallerOne,
                            ErrorIf.PadSmallerZero,
                            ErrorIf.PadLargerEqualKernel,
                        ]:
                            sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
                                testGen, error_name, s, p, k
                            )
                            if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                                # NOTE(review): 'shape' is passed positionally into the
                                # dot_products parameter here, leaving args_dict["shape"]
                                # as [] - dot_products is ignored for error tests, but
                                # confirm the empty shape is intended
                                arg_list.append(
                                    get_arg_list_element(a, sNew, pNew, kNew, shape)
                                )
                        elif (
                            n % sparsity == 0
                            # padding must not exceed the kernel size
                            and p[0] < k[0]
                            and p[1] < k[0]
                            and p[2] < k[1]
                            and p[3] < k[1]
                            # the padded shape must exceed the kernel size
                            and (shape[1] + p[0] + p[1]) > k[0]
                            and (shape[2] + p[2] + p[3]) > k[1]
                        ):
                            partial_h = shape[1] + p[0] + p[1] - k[0]
                            partial_w = shape[2] + p[2] + p[3] - k[1]
                            remainder_h = partial_h % s[0]
                            remainder_w = partial_w % s[1]
                            output_h = partial_h // s[0] + 1
                            output_w = partial_w // s[1] + 1
                            # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
                            if (
                                # the parameters must produce integer exact output
                                error_name != ErrorIf.PoolingOutputShapeNonInteger
                                and remainder_h == 0
                                and remainder_w == 0
                            ) or (
                                error_name == ErrorIf.PoolingOutputShapeNonInteger
                                and (remainder_h != 0 or remainder_w != 0)
                            ):
                                if (
                                    max_dim_size is not None
                                    and max(output_h, output_w) > max_dim_size
                                ):
                                    # Test will consume too much memory - skip it
                                    continue
                                # Dot products = N*OH*OW*C
                                dp = gtu.product(
                                    (shape[0], output_h, output_w, shape[3])
                                )
                                arg_list.append(
                                    get_arg_list_element(a, s, p, k, dp, shape)
                                )
                        n += 1

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2618
    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype, error_name=None):
        """Generate the output-type argument combinations for the CAST operator.

        Returns a list of (arg_str, args_dict) tuples with args_dict holding
        the "out_type" for each legal (or deliberately illegal) cast.
        """
        arg_list = []

        # Enumerate the output types here
        if error_name == ErrorIf.WrongOutputType:
            dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
        elif inDtype == DType.INT8:
            dtypeList = [
                DType.BOOL,
                DType.INT16,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT16:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT32:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT16,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FP16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
        elif inDtype == DType.BF16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
        elif inDtype == DType.FP32:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output type for incorrect input type
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(
                ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
            )

        # Now add data generator types
        # NOTE(review): 'dtype' below is the leaked loop variable, i.e. the *last*
        # entry of dtypeList, not the input dtype (other ag* generators pass their
        # input dtype to _add_data_generators) - confirm this is intended
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        return arg_list
2682
    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        """Generate output-type / scale32 / double-round / per-channel argument
        combinations for the RESCALE operator.

        Returns a list of (arg_str, arg_values) tuples where arg_values is
        [outDtype, scale32, double_round, per_channel].
        """
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                # These output types may legally carry a non-zero zero point
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                # Not a suitable wrong-output-type combination
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition. Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [outDtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list
2781
2782 @staticmethod
2783 def agMul(testGen, opName, shapeList, dtype, error_name=None):
2784 arg_list = []
2785
2786 if dtype is DType.INT32:
2787 for p in range(testGen.args.num_rand_permutations):
2788
2789 shift = testGen.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002790 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002791 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002792 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002793
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002794 arg_list = TosaArgGen._add_data_generators(
2795 testGen,
2796 opName,
2797 dtype,
2798 arg_list,
2799 error_name,
2800 )
2801 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002802 return arg_list
2803
2804 @staticmethod
2805 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
2806 arg_list = []
2807
2808 arg_list.append(("roundTrue", [True]))
2809 arg_list.append(("roundFalse", [False]))
2810
2811 return arg_list
2812
Luke Hutton57287132023-02-06 14:54:18 +00002813 @staticmethod
2814 def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
2815 arg_list = []
2816
Jeremy Johnsonc8330812024-01-18 16:57:28 +00002817 shape = shapeList[0]
2818 dot_products = gtu.product(shape)
2819 ks = 2 * shape[1] * shape[2] # 2*H*W
2820 for inverse in (True, False):
2821 args_dict = {
2822 "dot_products": dot_products,
2823 "shape": shape,
2824 "ks": ks,
2825 "acc_type": dtype,
2826 "inverse": inverse,
2827 }
2828 arg_list.append((f"inverse{inverse}", args_dict))
Luke Hutton57287132023-02-06 14:54:18 +00002829
Jeremy Johnsonc8330812024-01-18 16:57:28 +00002830 arg_list = TosaArgGen._add_data_generators(
2831 testGen,
2832 opName,
2833 dtype,
2834 arg_list,
2835 error_name,
2836 )
2837 # Return list of tuples: (arg_str, args_dict)
Luke Hutton57287132023-02-06 14:54:18 +00002838 return arg_list
2839
Jeremy Johnson6f57e6e2024-01-30 16:10:50 +00002840 @staticmethod
2841 def agRFFT2d(testGen, opName, shapeList, dtype, error_name=None):
2842 arg_list = []
2843
2844 shape = shapeList[0]
2845 dot_products = gtu.product(shape)
2846 ks = shape[1] * shape[2] # H*W
2847 args_dict = {
2848 "dot_products": dot_products,
2849 "shape": shape,
2850 "ks": ks,
2851 "acc_type": dtype,
2852 }
2853 arg_list.append(("", args_dict))
2854
2855 arg_list = TosaArgGen._add_data_generators(
2856 testGen,
2857 opName,
2858 dtype,
2859 arg_list,
2860 error_name,
2861 )
2862 # Return list of tuples: (arg_str, args_dict)
2863 return arg_list
2864
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002865 # Helper function for reshape. Gets some factors of a larger number.
2866 @staticmethod
2867 def getFactors(val, start=1):
2868 factors = []
2869
2870 for i in range(start, int(np.sqrt(val)) + 1):
2871 if (val % i) == 0:
2872 factors.append(i)
2873
2874 return factors
2875
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        """Generate random new-shape argument combinations for RESHAPE.

        Each new shape preserves the total element count of the input shape.
        Returns a list of (arg_str, args_dict) tuples with args_dict holding
        the "new_shape".
        """
        arg_list = []

        origShape = shapeList[0]
        totalElements = gtu.product(origShape)
        factors = TosaArgGen.getFactors(totalElements)

        # Find new shapes up to the number of permutations asked for
        # This code is NOT fast. Fortunately, the numbers are fairly small.
        for p in range(testGen.args.num_rand_permutations):
            # Rank from 1 to TOSA_TENSOR_MAX_RANK
            newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
            if len(factors) < newRank:
                # Not enough factors to build a shape of this rank
                continue

            # escape_counter limits the generation of new shapes to a reasonable time
            for escape_counter in range(100):

                # Generate the new shape of the chosen new rank
                newShape = []
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # Last dimension takes whatever element count is left
                newShape.append(remainingElements)

                # Check for duplicates
                duplicate = False
                for name, args_dict in arg_list:
                    if args_dict["new_shape"] == newShape:
                        duplicate = True
                        break

                if not duplicate:
                    outShape = "x".join([str(x) for x in newShape])
                    arg_list.append(
                        (
                            "perm{}_rank{}_out{}".format(p, newRank, outShape),
                            {"new_shape": newShape},
                        )
                    )
                    # Found an output shape for this permutation
                    break

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        return arg_list
2936
    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
        """Generate axis-permutation argument combinations for TRANSPOSE.

        Returns a list of (arg_str, args_dict) tuples with args_dict holding
        the "perms" list.
        """
        arg_list = []

        ifm_shape = shapeList[0]

        if error_name == ErrorIf.IndexOutsideBounds:
            # Indices guaranteed to be outside the valid axis range
            incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
            incorrect_small_index = range(-len(ifm_shape), 0)
            permutations = [p for p in itertools.permutations(incorrect_large_index)]
            permutations.extend(
                [p for p in itertools.permutations(incorrect_small_index)]
            )
        elif error_name == ErrorIf.IndexUsedTwice:
            # Create list with a duplicated index
            perm_range = list(range(len(ifm_shape)))
            index_choice = testGen.rng.choice(range(len(perm_range)))
            perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
            permutations = [p for p in itertools.permutations(perm_range)]

        else:
            # Get all permutations
            permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Get random permutation generator that uses all permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create list of required amount of permutations
        arg_list = [
            ("perm{}".format(p), {"perms": random_permutations[p].tolist()})
            for p in range(limit)
        ]
        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2982
    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype, error_name=None):
        """Generate random start/size argument combinations for SLICE.

        Returns a list of (arg_str, args_dict) tuples with args_dict holding
        "start" and "size" lists (one entry per dimension).
        """
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            start = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    # Random start inside the dimension and a size that fits
                    start.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - start[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    # Dimension of 1 - slice must cover it fully
                    start.append(0)
                    size.append(1)

            if valid:
                # If ERROR_IF test required then incorrect start, size will be returned
                start, size = TosaErrorIfArgGen.eiSliceErrorIf(
                    testGen, error_name, ifm_shape, start, size
                )
                arg_list.append(("perm{}".format(p), {"start": start, "size": size}))
        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
3024
3025 @staticmethod
3026 def agTile(testGen, opName, shapeList, dtype, error_name=None):
3027 arg_list = []
3028
3029 ifm_shape = shapeList[0]
3030 rank = len(ifm_shape)
3031
3032 for p in range(testGen.args.num_rand_permutations):
3033
3034 # Pick a few random, but small multiple values
3035 # because otherwise this has a tendency to generate
3036 # enormous tensors
3037 multiples = []
3038 for i in range(rank):
3039 if ifm_shape[i] > 1000:
3040 # Multiple of 1 if ifm_shape dimension is large to reduce
3041 # tensor size
3042 multiples.append(1)
3043 elif max(ifm_shape) > 1000:
3044 multiples.append(2)
3045 else:
3046 multiples.append(testGen.randInt(1, 4))
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003047 arg_list.append(("perm{}".format(p), {"multiples": multiples}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003048
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00003049 # Now add data generator types
3050 arg_list = TosaArgGen._add_data_generators(
3051 testGen,
3052 opName,
3053 dtype,
3054 arg_list,
3055 error_name,
3056 )
3057 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003058 return arg_list
3059
3060 @staticmethod
3061 def agResize(testGen, opName, shapeList, dtype, error_name=None):
3062 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003063 ifm_shape = shapeList[0]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003064
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003065 def get_aspect_ratio_resize_params():
3066 common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
3067 aspect_ratio = testGen.rng.choice(common_aspect_ratios)
3068 invert = testGen.rng.choice((False, True))
3069 letterbox = testGen.rng.choice((False, True))
3070
3071 scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
3072 scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
3073 scale_y_d = scale_x_d = 1
3074 offset_x = offset_y = 0
3075
3076 if letterbox:
3077 max_border = scale_y_n
3078 border_y = testGen.randInt(low=0, high=max_border)
3079 border_x = 0
3080 else:
3081 # Pillarboxing
3082 border_y = 0
3083 max_border = scale_x_n
3084 border_x = testGen.randInt(low=0, high=max_border)
3085
3086 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3087 offset = (offset_y, offset_x)
3088 border = (border_y, border_x)
3089
3090 return scale, offset, border
3091
3092 def get_upscale_downscale_params():
3093 valid_params = False
3094 while not valid_params:
3095 upscale = testGen.rng.choice((False, True))
3096
3097 # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
3098 origin_sampling = testGen.rng.choice((False, True))
3099
3100 if upscale:
3101 shift = testGen.randInt(low=1, high=4)
3102 scale_x_d = scale_y_d = 1
3103 scale_x_n = scale_y_n = (
3104 1 << shift if origin_sampling else 2 << shift
3105 )
3106 border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
3107 offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
3108 else:
3109 scale_x_n = 1
3110 scale_y_n = 1
3111
3112 # Return list of valid scale_*_d values (max value 4) given input dim shape
3113 def get_valid_denom(ifm_dim):
3114 return [x for x in range(1, 5) if ifm_dim % x == 1]
3115
3116 # Generate list of valid downscale values and choose one randomly
3117 valid_scale_y_ds = get_valid_denom(ifm_shape[1])
3118 valid_scale_x_ds = get_valid_denom(ifm_shape[2])
3119
3120 if not valid_scale_y_ds and not valid_scale_x_ds:
3121 # Bad parameters, skip
3122 continue
3123
3124 if not valid_scale_y_ds:
3125 scale_y_d = 1
3126 else:
3127 scale_y_d = testGen.rng.choice(valid_scale_y_ds)
3128
3129 if not valid_scale_x_ds:
3130 scale_x_d = 1
3131 else:
3132 scale_x_d = testGen.rng.choice(valid_scale_x_ds)
3133
3134 border_x = border_y = 0
3135 offset_y = testGen.randInt(0, 16 * scale_y_n)
3136 offset_x = testGen.randInt(0, 16 * scale_x_n)
3137 valid_params = True
3138
3139 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3140 offset = (offset_y, offset_x)
3141 border = (border_y, border_x)
3142 return scale, offset, border
3143
3144 def get_rand_params():
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003145 def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
3146 scale = scale_n / scale_d
3147 if scale > max_scale:
3148 factor = scale / max_scale
3149 new_scale_d = math.ceil(scale_d * factor)
3150 assert scale_n / new_scale_d <= max_scale
3151 scale_d = new_scale_d
3152 return scale_d
3153
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003154 # Scale
3155 scale_y_n = testGen.randInt(low=1, high=(1 << 11))
3156 scale_x_n = testGen.randInt(low=1, high=(1 << 11))
3157
3158 scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
3159 scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))
3160
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003161 scale_y_d = fix_scale_to_max_scale(
3162 scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3163 )
3164 scale_x_d = fix_scale_to_max_scale(
3165 scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3166 )
3167
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003168 # Offsets and border within the scale
3169 offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
3170 offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
3171 border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
3172 border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)
3173
3174 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3175 offset = (offset_y, offset_x)
3176 border = (border_y, border_x)
3177 return scale, offset, border
3178
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003179 def get_level_8k_params():
3180 # Create 64x scale - 64/1 to 2048/32
3181 scale_d = testGen.randInt(
3182 low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
3183 )
3184 scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
3185 # Create half to fifth scaling
3186 scale_d_alt = testGen.randInt(low=2, high=6)
3187 scale_n_alt = 1
3188 switch = testGen.rng.choice((False, True))
3189 if switch:
3190 scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
3191 else:
3192 scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)
3193
3194 offset_y = testGen.rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
3195 offset_x = testGen.rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
3196 offset = (offset_y, offset_x)
3197 border_y = testGen.rng.choice((-16 * scale[0], 0, scale[0] - 1))
3198 border_x = testGen.rng.choice((-16 * scale[2], 0, scale[2] - 1))
3199 border = (border_y, border_x)
3200 return scale, offset, border
3201
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003202 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003203 # Exclude illegal {mode, type} configurations. Pick legal output types
3204 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
3205 outputDTypeList = [DType.INT8]
3206 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
3207 outputDTypeList = [DType.INT16]
3208 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
3209 outputDTypeList = [DType.INT32]
3210 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
3211 outputDTypeList = [DType.INT48]
James Ward8b390432022-08-12 20:48:56 +01003212 elif dtype == DType.FP16:
3213 outputDTypeList = [DType.FP16]
James Ward24dbc422022-10-19 12:20:31 +01003214 elif dtype == DType.BF16:
3215 outputDTypeList = [DType.BF16]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003216 elif dtype == DType.FP32:
3217 outputDTypeList = [DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003218 elif error_name == ErrorIf.WrongInputType:
3219 # If an incorrect input type is used then we set a 'correct'
3220 # output type to avoid other errors
3221 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
3222 else:
3223 continue
3224
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003225 arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"
3226
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003227 for outputDType in outputDTypeList:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003228 perm = 0
3229 while perm < testGen.args.num_rand_permutations:
3230 # Random choice of type of params we are testing
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003231 if not testGen.args.level8k:
3232 _rnd_param_fn = testGen.rng.choice(
3233 (
3234 get_rand_params,
3235 get_upscale_downscale_params,
3236 get_aspect_ratio_resize_params,
3237 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003238 )
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003239 scale, offset, border = _rnd_param_fn()
3240 else:
3241 scale, offset, border = get_level_8k_params()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003242
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003243 # Expand params for bounds-checking
3244 (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
3245 (offset_y, offset_x) = offset
3246 (border_y, border_x) = border
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003247
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003248 # Make sure output dimensions OH and OW are integers
3249 partial_output_y = (
3250 (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
3251 )
3252 partial_output_x = (
3253 (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
3254 )
3255 if error_name == ErrorIf.ResizeOutputShapeNonInteger:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003256 # Look for non-integer test
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003257 if (
3258 partial_output_y % scale_y_d == 0
3259 and partial_output_x % scale_x_d == 0
3260 ):
3261 # Skip this test as it doesn't produce NonInteger output
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003262 if perm > 0:
3263 perm += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003264 continue
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003265 else:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003266 # Alter the scaling factors to make the output integer
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003267 while partial_output_y % scale_y_d != 0:
3268 scale_y_d -= 1
3269 while partial_output_x % scale_x_d != 0:
3270 scale_x_d -= 1
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003271 # Make sure we are still within max scaling
3272 if (
3273 scale_y_n / scale_y_d
3274 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
3275 scale_x_n / scale_x_d
3276 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
3277 # Skip the test as it is using too large a scaling factor
3278 if perm > 0:
3279 perm += 1
3280 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003281
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003282 output_y = partial_output_y // scale_y_d + 1
3283 output_x = partial_output_x // scale_x_d + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003284
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003285 if (
3286 output_y >= testGen.args.max_resize_output_dim
3287 or output_x >= testGen.args.max_resize_output_dim
3288 ) and error_name is None:
3289 # Skip positive test if output dim will be too high
3290 # Avoid high test latency and OOM issues
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003291 if not testGen.args.level8k or perm > 0:
3292 perm += 1
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003293 continue
3294
3295 if (
3296 output_y <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003297 or output_y >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003298 or output_x <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003299 or output_x >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003300 ):
3301 # Output dimensions out of scope
3302 if error_name is not None and perm > 0:
3303 # As long as we have one ERROR_IF test, don't worry
3304 # about creating all the other permutations
3305 perm += 1
3306 continue
3307
3308 if error_name == ErrorIf.ResizeOutputShapeMismatch and (
3309 (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003310 output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003311 and output_y - scale_y_d < 1
3312 )
3313 or (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003314 output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003315 and output_x - scale_x_d < 1
3316 )
3317 ):
3318 # Can't create a negative test with these params as it
3319 # will create invalid output size
3320 if perm > 0:
3321 perm += 1
3322 continue
3323
3324 scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
3325 offset = [offset_y, offset_x]
3326 border = [border_y, border_x]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003327
3328 # Common for all data types
3329 if error_name is not None:
3330 (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003331 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003332 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003333 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003334 outputDTypeNew,
3335 ) = TosaErrorIfArgGen.eiResizeErrorIf(
3336 testGen,
3337 error_name,
3338 mode,
3339 dtype,
3340 shapeList,
3341 outputDType,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003342 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003343 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003344 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003345 )
3346 else:
3347 outputDTypeNew = outputDType
3348
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003349 arg_to_append = (
3350 arg_str.format(
3351 "N" if mode == ResizeMode.NEAREST else "B",
3352 testGen.typeStr(outputDTypeNew),
3353 scale[0],
3354 scale[1],
3355 scale[2],
3356 scale[3],
3357 offset[0],
3358 offset[1],
3359 border[0],
3360 border[1],
3361 ),
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003362 {
3363 "mode": mode,
3364 "scale": scale,
3365 "offset": offset,
3366 "border": border,
3367 "output_dtype": outputDTypeNew,
3368 },
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003369 )
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003370 if arg_to_append in arg_list:
3371 # Skip already generated test params
3372 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003373
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003374 # Valid permutation
3375 perm += 1
3376 arg_list.append(arg_to_append)
Jeremy Johnson32d0b5a2024-02-01 15:54:07 +00003377
3378 # Now add data generator types
3379 arg_list = TosaArgGen._add_data_generators(
3380 testGen,
3381 opName,
3382 dtype,
3383 arg_list,
3384 error_name,
3385 )
3386 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003387 return arg_list
3388
3389 @staticmethod
3390 def agTable(testGen, opName, shapeList, dtype, error_name=None):
3391 arg_list = []
3392
3393 if dtype == DType.INT8:
3394 table = np.int32(
3395 testGen.rng.integers(low=-128, high=128, size=[256])
3396 ).tolist()
3397 else: # INT16
3398 table = np.int32(
3399 testGen.rng.integers(low=-32768, high=32768, size=[513])
3400 ).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003401 # Make sure all slopes are within REQUIRE min/max 16-bit int
3402 for idx in range(len(table) - 1):
3403 slope = table[idx + 1] - table[idx]
3404 # Alter the next table entry to force the slope to be ok
3405 if slope > 32767:
3406 table[idx + 1] -= slope - 32767
3407 if slope < -32768:
3408 table[idx + 1] -= slope + 32768
3409 slope = table[idx + 1] - table[idx]
3410 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003411 arg_list.append(
3412 (
3413 "",
3414 [table],
3415 )
3416 )
3417 return arg_list
3418
3419 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
3420 # CondIf generates the condition values here.
3421 # Convert to tensors in the build function, along with the
3422 # then and else blocks
3423 arg_list = []
3424
3425 for c in [False, True]:
3426 arg_list.append(("cond{}".format(int(c)), [c]))
3427
3428 return arg_list
3429
3430 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
3431 # While loop: 0 iterations, 1, more than 1
3432 arg_list = []
3433
3434 for iter in [0, 1, 4]:
3435 arg_list.append(("iter{}".format(iter), [iter]))
3436
3437 return arg_list