blob: 386e243c08d46c378705ca0de261b0e6bd9dd560 [file] [log] [blame]
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
4import math
James Ward8b390432022-08-12 20:48:56 +01005import warnings
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01006
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
16# DTypeNames, DType, Op and ResizeMode are convenience variables to the
17# flatc-generated types that should be enums, but aren't
18
19
20class TosaQuantGen:
21 """QuantizedInfo random generator helper functions.
22
23 Specify with 'qgen': in the operator defintion.
24 """
25
26 def __init__(self):
27 pass
28
29 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +000030 def getZeroPoint(testGen, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010031
32 if dtype == DType.INT8:
Jeremy Johnson00423432022-09-12 17:27:37 +010033 if testGen.args.zeropoint is not None:
34 return min(127, max(-128, testGen.args.zeropoint))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010035 return testGen.randInt(-128, 128)
36 elif dtype == DType.UINT8:
Jeremy Johnson00423432022-09-12 17:27:37 +010037 if testGen.args.zeropoint is not None:
38 return min(255, max(0, testGen.args.zeropoint))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010039 return testGen.randInt(0, 256)
40 elif error_name in [
41 ErrorIf.InputZeroPointNotZero,
42 ErrorIf.WeightZeroPointNotZero,
43 ErrorIf.OutputZeroPointNotZero,
44 ]:
45 zero_point = testGen.randInt(-128, 128)
46 if zero_point == 0:
47 zero_point = 1
48 return zero_point
49 return 0
50
51 @staticmethod
52 def qgUnary(testGen, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010053 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000054 qinfo = [
55 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
56 TosaQuantGen.getZeroPoint(testGen, dtype),
57 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010058 elif error_name == ErrorIf.OutputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000059 qinfo = [
60 TosaQuantGen.getZeroPoint(testGen, dtype),
61 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
62 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010063 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000064 qinfo = [
65 TosaQuantGen.getZeroPoint(testGen, dtype),
66 TosaQuantGen.getZeroPoint(testGen, dtype),
67 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010068 return qinfo
69
70 @staticmethod
71 def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010072 if isinstance(dtype_or_dtypeList, list):
73 # a list of [input, weights, accumulator] dtypes
74 dtypeList = dtype_or_dtypeList
75 else:
76 # an int, [input, weights, accumulator] dtypes are the same
77 dtypeList = [dtype_or_dtypeList] * 3
78
79 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000080 qinfo = [
81 TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
82 TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
83 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010084 elif error_name == ErrorIf.WeightZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000085 qinfo = [
86 TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
87 TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
88 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010089 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000090 qinfo = [
91 TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
92 TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
93 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010094 return qinfo
95
96 @staticmethod
97 def qgMatmul(testGen, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010098 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000099 qinfo = [
100 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
101 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
102 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100103 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000104 qinfo = [
105 TosaQuantGen.getZeroPoint(testGen, dtype),
106 TosaQuantGen.getZeroPoint(testGen, dtype),
107 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100108 return qinfo
109
110 @staticmethod
111 def computeMultiplierAndShift(scaleFp, scale32):
112 # Derived from computeMultiplierAndShiftTosaScale32
113 # Provide a floating-point scaling factor and the scale32 parameter
114 # to compute the multiplier and shift
115
116 if scale32:
117 scaleBits = 31
118 else:
119 scaleBits = 15
120
121 m, shift = math.frexp(scaleFp)
122
123 if scaleFp < 0.0:
124 m = -m
125
126 multiplier = round(m * (1 << scaleBits))
127 assert multiplier <= (1 << scaleBits)
128
129 if multiplier == (1 << scaleBits):
130 multiplier = multiplier // 2
131 shift = shift + 1
132
133 shift = (-shift) + scaleBits
134 # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
135 # scaleFp, scaleBits, m, multiplier, shift))
136
137 # Adjust multiplier such that shift is in allowed value range.
138 if shift == 0:
139 multiplier = multiplier // 4
140 shift = shift + 2
141 elif shift == 1:
142 multiplier = multiplier // 2
143 shift = shift + 1
144 elif shift == 63:
145 multiplier = multiplier * 2
146 shift = shift - 1
147
148 assert multiplier <= (1 << scaleBits)
149 assert shift >= 2 and shift <= 62
150
151 return multiplier, shift
152
153
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.

    Each tg* static method returns a list of shapes - one per operand tensor.
    `testGen` supplies the random helpers (makeShape, randInt, rng, ...) and
    the parsed command line arguments.
    """

    def __init__(self):
        # Stateless helper class - all generators are static methods.
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Create one random shape of the given rank, repeated for every operand.

        For ErrorIf.RankMismatch, later operands get shapes of a different rank.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    # Rank 1 can only grow (cannot go to rank 0)
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Create a single batch-limited NHWC shape, repeated for every operand."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgGather(testGen, opName, rank, error_name=None):
        """Create [values (N,K,C), indices (N,W)] shapes for GATHER."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_shape = testGen.makeShape(rank)
        values_shape = testGen.constrictBatchSize(values_shape)

        N = values_shape[0]
        W = testGen.makeDimension()
        indices_shape = [N, W]

        shape_list = [values_shape, indices_shape]
        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Create [values_in (N,K,C), indices (N,W), input (N,W,C)] shapes for SCATTER."""
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = testGen.randInt(W_min, W_max) if W_min < W_max else W_min

        input_shape = [N, W, C]

        shape_list = []
        shape_list.append(values_in_shape)
        shape_list.append([N, W])  # indices
        shape_list.append(input_shape)

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Create operand shapes where one randomly chosen input is broadcast
        along one randomly chosen dimension."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        fuzz_idx = testGen.randInt(0, rank)

        for i in range(pl + const):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    # Dimension can no longer broadcast (not 1 and not equal)
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (OHWI), bias (OC)] shapes for CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Create [ifm (NDHWC), filter (ODHWI), bias (OC)] shapes for CONV3D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension()

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (OHWI), bias (OC)] shapes for TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (HWCM), bias (C*M)] shapes for DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        """Create two identical (N,H,W) shapes for FFT2D with power-of-two H/W."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        """Create a single (N,H,W) shape for RFFT2D with power-of-two H/W."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Create [input (N,IC), filter (OC,IC), bias (OC)] shapes for FULLY_CONNECTED."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Create [a (N,H,C), b (N,C,W)] shapes for MATMUL."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Create a variable number of identical shapes for CONCAT (with rank
        mismatches for the ConcatInputRankMismatch error test)."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Rework a CONCAT shape list by splitting the first shape along axis so
        multiple const inputs concatenate back to the original size."""
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
624
625
626class TosaTensorValuesGen:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100627 """Tensor Value generators create the random data for each tensor in each test."""
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100628
629 def __init__(self):
630 pass
631
Jeremy Johnson1271c442023-09-05 11:39:26 +0100632 class TVGInfo:
633 """Enhanced tensor values information including data gen dict."""
634
635 def __init__(self, tensorList, dataGenDict):
636 self.tensorList = tensorList
637 self.dataGenDict = dataGenDict
638
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100639 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000640 def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100641 pCount, cCount = op["operands"]
642
643 tens = []
644 tens.extend(
645 testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
646 )
647 tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
648
649 return tens
650
    # Default high value for random numbers
    # These are the largest finite values of each FP format:
    #   (2 - 2^-mantissa_bits) * 2^max_exponent
    # e.g. FP16: 2^16 - 2^5 = 65504
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
    }

    # Default lowest normal values for random numbers
    # i.e. 2^min_normal_exponent of each FP format
    TVG_FLOAT_LOW_VALUE = {
        DType.FP32: np.exp2(-126),
        DType.FP16: np.exp2(-14),
        DType.BF16: np.exp2(-126),
    }
664
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100665 @staticmethod
Jeremy Johnson30476252023-11-20 16:15:30 +0000666 def _get_data_range(testGen, dtype, highValueLookup, lowValueLookup=None):
667 # Return a tuple of (low,high) data range values for the given data
668 # type using a combination of per operator table limits, data limits
669 # and user supplied ranges for FP numbers
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000670 if dtype in highValueLookup:
Jeremy Johnson30476252023-11-20 16:15:30 +0000671 type_range = testGen.getDTypeRange(dtype, high_inclusive=True)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000672 high_val = highValueLookup[dtype]
Jeremy Johnson30476252023-11-20 16:15:30 +0000673 if lowValueLookup is not None and dtype in lowValueLookup:
674 low_val = lowValueLookup[dtype]
675 else:
676 low_val = -high_val
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000677 # Set the values to something that won't produce infinity whilst
Jeremy Johnson30476252023-11-20 16:15:30 +0000678 # respecting the default ranges if more/less than the low/high
679 # values
680 data_range = (
681 max(low_val, type_range[0]),
682 min(high_val, type_range[1]),
683 )
684 if data_range[0] > data_range[1]:
685 # Invalid data range from low to high created due to user
686 # constraints revert to using internal ranges as they are
687 # known to work
688 msg = f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
689 warnings.warn(msg)
690 data_range = (low_val, high_val)
691 return data_range
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000692 return None
693
    @staticmethod
    def tvgLazyGenDefault(
        testGen, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Create test tensors, deferring data creation to the data generator
        library where possible.

        Returns a TosaTensorValuesGen.TVGInfo holding the serializer tensors
        and, when the data generator library is used, the JSON meta-data that
        describes how each tensor's data is produced.  May also set
        argsDict["ksb"] for DOT_PRODUCT tests with non-zero bias data.
        """
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
        if "p_count" in argsDict:
            # Override for operators like CONCAT
            pCount = argsDict["p_count"]
            cCount = argsDict["c_count"]
        assert pCount + cCount == len(
            shapeList
        ), "Placeholders & Constant tensors must match shapes list"

        tens_ser_list = []

        if (
            error_name is not None
            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
            or "data_gen" not in testGen.TOSA_OP_LIST[opName]
        ):
            # Fall back to internal data gen when dealing with unsupported types or ops
            data_range = argsDict["data_range"] if "data_range" in argsDict else None
            for idx, info in enumerate(zip(shapeList, dtypeList)):
                roundMode = False
                shape, dtype = info
                if "data_range_list" in argsDict:
                    # Per-tensor range overrides the single data_range
                    data_range = argsDict["data_range_list"][idx]["range"]
                    roundMode = (
                        "round" in argsDict["data_range_list"][idx]
                        and argsDict["data_range_list"][idx]["round"] is True
                    )
                if data_range is not None and dtype not in (
                    DType.FP16,
                    DType.FP32,
                    DType.BF16,
                ):
                    # Change from inclusive to exclusive range
                    data_range = (data_range[0], data_range[1] + 1)
                # Ignore lazy data gen option and create data array using any range limits

                if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                    arr = np.int64(argsDict["fixed_data"][idx])
                else:
                    arr = testGen.getRandTensor(shape, dtype, data_range)
                if roundMode:
                    arr = np.round(arr)
                if idx < pCount:
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))

            # No data generator meta-data in the fallback case
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
        dg_type = argsDict["dg_type"]
        tens_data = {
            "version": "0.1",
            "tensors": {},
        }
        dg_tens_meta = tens_data["tensors"]
        for idx, shape in enumerate(shapeList):

            tens_meta = {}
            if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                tens_meta["generator"] = gtu.DataGenType(
                    gtu.DataGenType.FIXED_DATA
                ).name
            else:
                tens_meta["generator"] = gtu.DataGenType(dg_type).name

            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()

            if idx < pCount:
                tens_meta["input_type"] = "VARIABLE"
            else:
                tens_meta["input_type"] = "CONSTANT"

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                if (
                    tens_meta["generator"]
                    == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name
                ):
                    info["data"] = [int(i) for i in argsDict["fixed_data"][idx]]
                    tens_meta["fixed_data_info"] = info
                else:
                    # TODO - generate seed for this generator based on test
                    info["rng_seed"] = 42

                    data_range = None
                    if "data_range_list" in argsDict:
                        data_range = argsDict["data_range_list"][idx]["range"]
                        if "round" in argsDict["data_range_list"][idx]:
                            info["round"] = argsDict["data_range_list"][idx]["round"]
                    elif "data_range" in argsDict:
                        data_range = argsDict["data_range"]

                    if data_range is None:
                        # No range supplied - use the full range for the dtype
                        data_range = testGen.getDTypeRange(
                            dtypeList[idx], high_inclusive=True
                        )
                    info["range"] = [str(v) for v in data_range]
                    tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = int(argsDict["ks"])
                if "acc_type" in argsDict:
                    # Convert type number into JSON name
                    info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
                        "json"
                    ]
                if "kernel" in argsDict:
                    info["kernel"] = [int(k) for k in argsDict["kernel"]]
                if "axis" in argsDict:
                    info["axis"] = int(argsDict["axis"])
                tens_meta["dot_product_info"] = info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"

            # Using the finished generate config meta data - generate the data if
            # needed and assign a tensor name from the serializer

            # Need to generate data when not lazy or for the bias tensor as we need
            # to work out if the bias data is non-zero for compliance
            if not testGen.args.lazy_data_gen or (
                idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
            ):
                # Give this tensor a temporary name until we get one from the serializer
                temp_name = f"placeholder_{idx}"
                dg_tens_meta[temp_name] = tens_meta
                # Create data now using the temporary name to access meta details
                data = testGen.dgl.get_tensor_data(temp_name, tens_data)
                if tens_meta["data_type"] == "SHAPE":
                    # Tensor type SHAPE and Numpy file type must be the same
                    data = np.int64(data)
                # Remove the item as we will give it the correct name later
                del dg_tens_meta[temp_name]

                if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
                    # The KS value used by compliance verification is altered when the
                    # bias data is non-zero
                    if max(abs(data)) > 0.0:
                        argsDict["ksb"] = argsDict["ks"] + 1

            if testGen.args.lazy_data_gen:
                # Data will be generated lazily from the meta-data later
                data = None

            if tens_meta["input_type"] == "VARIABLE":
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
            else:
                tens = testGen.ser.addConst(shape, dtypeList[idx], data)

            tens_ser_list.append(tens)
            # Add the meta data to the list using the serializer tensor name
            dg_tens_meta[tens.name] = tens_meta

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
857
858 @staticmethod
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000859 def tvgNegate(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100860 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000861 # Integer test
862 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100863 pCount, cCount = op["operands"]
864 assert (
865 pCount == 1 and cCount == 0
866 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100867 # Must create tensors with values within accumulator (int32) negatable
868 # range
869 max_val = (1 << 31) - 1
870 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100871 arr = np.int32(
872 testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
873 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000874 tens_ser_list = []
875 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100876 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
877 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000878 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100879 else:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000880 # ERROR_IF or floating point test
881 return TosaTensorValuesGen.tvgLazyGenDefault(
882 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100883 )
884
Jeremy Johnson30476252023-11-20 16:15:30 +0000885 # Set the ADD/SUB data range to half the largest value to avoid infinities
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000886 TVG_FLOAT_HIGH_VALUE_ADDSUB = {
887 DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
888 DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
889 DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
890 }
891
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100892 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000893 def tvgAddSub(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Won Jeon74342e52024-01-09 00:34:40 +0000894 if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000895 # Make sure the integer operation does not cause value saturation - where
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100896 # the number wraps due to limited number of bits to store the answer
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000897 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100898 pCount, cCount = op["operands"]
899 assert (
900 pCount == 2 and cCount == 0
901 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000902 tens_ser_list = []
Won Jeon74342e52024-01-09 00:34:40 +0000903 add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
904 data_range = testGen.args.tensor_shape_range
905 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0], data_range)
906 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1], data_range)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100907 if add:
908 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
909 else:
910 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
911
912 # Work out the saturation limits
913 max_i32 = (1 << 31) - 1
914 min_i32 = -(1 << 31)
915 max_arr = np.full(shapeList[1], max_i32)
916 min_arr = np.full(shapeList[1], min_i32)
917
918 # Find how much values exceed the maximum/minimums
919 sat_max_arr = np.maximum(res_arr - max_arr, 0)
920 sat_min_arr = np.minimum(res_arr - min_arr, 0)
921
922 if not add:
923 # Swap saturation values and negate values as we need to perform opposite operations
924 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
925
926 # Create new array of unsaturated values by clipping values as needed
927 b_unsat_arr = b_arr
928 if (sat_max_arr != 0).any():
929 # Clip values that cause saturation
930 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
931 # Reduce axes in unsaturated tensor to match original tensor
932 for axis, dim in enumerate(b_arr.shape):
933 if dim != b_unsat_arr.shape[axis]:
934 assert (
935 dim == 1
936 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
937 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
938
939 if (sat_min_arr != 0).any():
940 # Clip values that cause saturation
941 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
942 # Reduce axes in unsaturated tensor to match original tensor
943 for axis, dim in enumerate(b_arr.shape):
944 if dim != b_unsat_arr.shape[axis]:
945 assert (
946 dim == 1
947 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
948 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
949
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000950 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100951 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
952 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000953 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100954 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
955 )
956
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000957 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100958 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000959 # ERROR_IF or floating point test
960 data_range = TosaTensorValuesGen._get_data_range(
961 testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
962 )
963 if data_range:
964 argsDict["data_range"] = data_range
965
966 return TosaTensorValuesGen.tvgLazyGenDefault(
967 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100968 )
969
970 @staticmethod
    def tvgCondIfWhileLoop(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        """Create input values for COND_IF / WHILE_LOOP tests.

        Note: old-style interface - ``op`` is the operator dictionary (not the
        op name) and the return value is a plain list of serializer tensors.
        """
        if dtypeList[0] in (
            DType.INT32,
            DType.INT16,
            DType.INT8,
        ):
            # Limit input tensors with cond_if_binary or while_loop to stop
            # saturation of add/sub ops with int32 and keep all logical shift
            # values between 0 to 31 for int16 or int8
            pCount, cCount = op["operands"]
            # Count of inputs still to be created as placeholders; the rest
            # become consts
            pRemain = pCount
            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[0] == DType.INT32:
                    # INT16-ranged data in the INT32 tensor avoids saturation
                    arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
                else:
                    # Shift amounts in [0, 31]
                    arr = np.int32(
                        testGen.rng.integers(low=0, high=32, size=shapeList[idx])
                    )
                if pRemain > 0:
                    placeholders.append(
                        testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
                    )
                    pRemain -= 1
                else:
                    placeholders.append(
                        testGen.ser.addConst(shape, dtypeList[idx], arr)
                    )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
1007
1008 @staticmethod
1009 def tvgArithmeticRightShift(
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001010 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001011 ):
1012 pCount, cCount = op["operands"]
1013 # Force value of operand[1] to be within [0, num_bits]
1014 assert (
1015 pCount == 2 and cCount == 0
1016 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
1017
1018 placeholders = []
1019 for idx, shape in enumerate(shapeList[:]):
1020 if idx == 1:
1021 if dtypeList[idx] == DType.INT8:
1022 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
1023 elif dtypeList[idx] == DType.INT16:
1024 arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
1025 elif dtypeList[idx] == DType.INT32:
1026 arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
1027 elif error_name == ErrorIf.WrongInputType:
1028 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
1029 else:
1030 raise Exception("OpArithmeticRightShift: invalid input dtype")
1031 else:
1032 arr = testGen.getRandTensor(shape, dtypeList[idx])
1033 placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
1034
1035 return placeholders
1036
1037 @staticmethod
Won Jeon64e4bfe2024-01-18 06:31:55 +00001038 def tvgReshape(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
1039 dtypeList[1] = DType.SHAPE
1040 shapeList[1] = [len(argsDict["new_shape"])]
1041 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1042 argsDict["fixed_data"] = [None, argsDict["new_shape"]]
1043
1044 return TosaTensorValuesGen.tvgLazyGenDefault(
1045 testGen, op, dtypeList, shapeList, argsDict, error_name
1046 )
1047
1048 @staticmethod
1049 def tvgTile(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
1050 dtypeList[1] = DType.SHAPE
1051 shapeList[1] = [len(argsDict["multiples"])]
1052 argsDict["fixed_data"] = [None, argsDict["multiples"]]
1053
1054 return TosaTensorValuesGen.tvgLazyGenDefault(
1055 testGen, op, dtypeList, shapeList, argsDict, error_name
1056 )
1057
1058 @staticmethod
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00001059 def tvgSelect(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001060 # Set datatype of condition tensor to boolean
1061 dtypeList[0] = DType.BOOL
1062
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00001063 return TosaTensorValuesGen.tvgLazyGenDefault(
1064 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001065 )
1066
1067 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001068 def tvgIntDiv(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001069 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001070 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001071 pCount, cCount = op["operands"]
1072 assert (
1073 pCount == 2 and cCount == 0
1074 ), "Op.INTDIV must have 2 placeholders, 0 consts"
1075
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001076 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001077
1078 # Two invalid cases for Op.INTDIV:
1079 # 1. divisor == 0
1080 # 2. dividend == -(1<<31) and divisor == -1
1081 while True:
1082 dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1083 divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
1084
1085 if (divisor_arr == 0).any():
1086 continue
1087
1088 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
1089 continue
1090
1091 break
1092
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001093 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001094 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1095 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001096 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001097 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1098 )
1099
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001100 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001101 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001102 return TosaTensorValuesGen.tvgLazyGenDefault(
1103 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001104 )
1105
Jeremy Johnson30476252023-11-20 16:15:30 +00001106 # Set the MUL data range to the square root of the largest value
1107 # to avoid infinities
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001108 TVG_FLOAT_HIGH_VALUE_MUL = {
1109 DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1110 DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1111 DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1112 }
1113
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001114 @staticmethod
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001115 def tvgMul(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1116 if error_name is not None or dtypeList[0] in (
1117 DType.FP16,
1118 DType.BF16,
1119 DType.FP32,
1120 ):
1121 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001122 data_range = TosaTensorValuesGen._get_data_range(
1123 testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
1124 )
1125 if data_range:
1126 argsDict["data_range"] = data_range
1127
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001128 return TosaTensorValuesGen.tvgLazyGenDefault(
1129 testGen, opName, dtypeList, shapeList, argsDict, error_name
1130 )
1131 else:
1132 # Integer test
1133 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001134 pCount, cCount = op["operands"]
1135 assert (
1136 pCount == 2 and cCount == 0
1137 ), "Op.MUL must have 2 placeholders, 0 consts"
1138
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001139 tens_ser_list = []
1140
1141 # Make sure multiply result in int32 range
Won Jeon74342e52024-01-09 00:34:40 +00001142 if dtypeList[0] == DType.SHAPE:
1143 shift = 0
1144 else:
1145 shift = argsDict["shift"]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001146 if dtypeList[0] == DType.INT8:
1147 num_bits = 8
1148 elif dtypeList[0] == DType.INT16:
1149 num_bits = 16
Won Jeon74342e52024-01-09 00:34:40 +00001150 elif dtypeList[0] in (DType.INT32, DType.SHAPE):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001151 num_bits = 32
1152 elif error_name == ErrorIf.WrongInputType:
1153 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001154 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001155 raise Exception("OpMul: invalid input dtype")
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001156
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001157 for idx, shape in enumerate(shapeList[:]):
Won Jeon74342e52024-01-09 00:34:40 +00001158 if dtypeList[idx] == DType.SHAPE:
1159 low = testGen.args.tensor_shape_range[0]
1160 high = testGen.args.tensor_shape_range[1]
1161 else:
1162 low = -(2 ** (num_bits - 1))
1163 high = (2 ** (num_bits - 1)) - 1
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001164
1165 a_arr = np.int32(
1166 testGen.rng.integers(low=low, high=high, size=shapeList[0])
1167 )
1168 b_arr = np.int32(
1169 testGen.rng.integers(low=low, high=high, size=shapeList[1])
1170 )
1171
1172 i = 0
1173 while True:
1174
1175 a_arr_64 = a_arr.astype(np.int64)
1176 b_arr_64 = b_arr.astype(np.int64)
1177
1178 if shift > 0:
1179 rounding = 1 << (shift - 1)
1180 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001181 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001182 result_arr = a_arr_64 * b_arr_64
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001183
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001184 if (result_arr > -(2**31)).all() and (
1185 result_arr <= ((2**31) - 1)
1186 ).all():
1187 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001188
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001189 i = i + 1
1190 a_arr = a_arr // 2
1191 b_arr = b_arr // 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001192
Won Jeon74342e52024-01-09 00:34:40 +00001193 if dtypeList[0] == DType.SHAPE:
1194 tens_ser_list.append(
1195 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
1196 )
1197 tens_ser_list.append(
1198 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
1199 )
1200 else:
1201 tens_ser_list.append(
1202 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1203 )
1204 tens_ser_list.append(
1205 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1206 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001207
1208 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001209
1210 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001211 def tvgConcat(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001212 count = len(shapeList) - testGen.args.num_const_inputs_concat
1213 if count < 1:
1214 count = 1
1215 if testGen.args.num_const_inputs_concat == 0:
1216 count = len(shapeList)
1217
Won Jeon74342e52024-01-09 00:34:40 +00001218 op = testGen.TOSA_OP_LIST[opName]
1219 if op["op"] == Op.CONCAT_SHAPE:
1220 # Set the axis to 0
1221 shapeList = TosaTensorGen.tgConcatConstInput(
1222 testGen, shapeList, 0, error_name
1223 )
1224 else:
1225 shapeList = TosaTensorGen.tgConcatConstInput(
1226 testGen, shapeList, argsDict["axis"], error_name
1227 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001228
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001229 # Override default pCount/cCount for operator
1230 argsDict["p_count"] = count
1231 argsDict["c_count"] = len(shapeList) - count
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001232
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001233 return TosaTensorValuesGen.tvgLazyGenDefault(
1234 testGen, opName, dtypeList, shapeList, argsDict, error_name
1235 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001236
1237 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001238 def tvgLogicalShift(
1239 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1240 ):
1241 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001242 pCount, cCount = op["operands"]
1243 assert (
1244 pCount == 2 and cCount == 0
1245 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
1246 values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1247 shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001248 tens_ser_list = []
1249 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001250 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1251 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001252 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001253 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1254 )
1255
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001256 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001257
1258 @staticmethod
Jeremy Johnsona0150012023-11-15 15:52:06 +00001259 def tvgEqual(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1260 if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
1261 # Integer
1262 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001263 pCount, cCount = op["operands"]
1264 assert (
1265 pCount == 2 and cCount == 0
1266 ), "Op.EQUAL must have 2 placeholders, 0 consts"
Jeremy Johnsona0150012023-11-15 15:52:06 +00001267
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001268 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1269 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
Jeremy Johnsona0150012023-11-15 15:52:06 +00001270
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001271 # Using random numbers means that it will be very unlikely that
1272 # there are any matching (equal) values, therefore force that
1273 # there are twice the number of matching values as the tensor rank
1274 for num in range(0, len(shapeList[0]) * 2):
1275 a_index = []
1276 b_index = []
1277 # Choose an index in each axis for the whole shape
1278 for axis in range(0, len(shapeList[0])):
1279 # Index can be up to the largest dimension in both shapes
1280 index = np.int32(
1281 testGen.rng.integers(
1282 0, max(shapeList[0][axis], shapeList[1][axis])
1283 )
1284 )
1285 # Reduce the index down to a shape's dim for broadcasting
1286 a_index.append(min(shapeList[0][axis] - 1, index))
1287 b_index.append(min(shapeList[1][axis] - 1, index))
1288
1289 a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
1290
Jeremy Johnsona0150012023-11-15 15:52:06 +00001291 tens_ser_list = []
1292 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001293 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1294 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001295 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001296 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1297 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001298 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001299 else:
Jeremy Johnsona0150012023-11-15 15:52:06 +00001300 # ERROR_IF or floating point test
1301 return TosaTensorValuesGen.tvgLazyGenDefault(
1302 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001303 )
1304
1305 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001306 def tvgReduceSum(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001307 dtype = dtypeList[0]
1308 if dtype == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001309 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001310 pCount, cCount = op["operands"]
1311 assert (
1312 pCount == 1 and cCount == 0
1313 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
1314 # Limit values so that the sum cannot exceed the range of an int32 during
1315 # summation of any axis
1316 range_val = int((1 << 31) / max(shapeList[0]))
1317 values_arr = np.int32(
1318 testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
1319 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001320 tens_ser_list = []
1321 tens_ser_list.append(
Jeremy Johnson30476252023-11-20 16:15:30 +00001322 testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001323 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001324 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001325 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001326 # ERROR_IF or dot product floating point test
Jeremy Johnson30476252023-11-20 16:15:30 +00001327 if (
1328 error_name is None
1329 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
1330 ):
1331 # Limit ranges for (non error & non compliance) tests by using
1332 # values that can be summed on any axis to not hit infinity
1333 highval_lookup = {
1334 dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
1335 / max(shapeList[0])
1336 }
1337 data_range = TosaTensorValuesGen._get_data_range(
1338 testGen, dtype, highval_lookup
1339 )
1340 assert data_range is not None
1341 argsDict["data_range"] = data_range
1342
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001343 return TosaTensorValuesGen.tvgLazyGenDefault(
1344 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001345 )
1346
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001347 @staticmethod
1348 def tvgReduceProduct(
1349 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1350 ):
1351 dtype = dtypeList[0]
1352 if error_name is None:
1353 # Limit ranges for (non error) tests by using
1354 # values that can be multiplied on any axis to not hit infinity
1355 highval_lookup = {
1356 dtype: math.pow(
1357 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
1358 1 / max(shapeList[0]),
1359 )
1360 }
1361 data_range = TosaTensorValuesGen._get_data_range(
1362 testGen, dtype, highval_lookup
1363 )
1364 assert data_range is not None
1365 argsDict["data_range"] = data_range
1366
1367 return TosaTensorValuesGen.tvgLazyGenDefault(
1368 testGen, opName, dtypeList, shapeList, argsDict, error_name
1369 )
1370
Jeremy Johnson30476252023-11-20 16:15:30 +00001371 # Set the POW exponent high data range
1372 TVG_FLOAT_HIGH_VALUE_POW_EXP = {
1373 DType.FP32: 10.0,
1374 DType.FP16: 10.0,
1375 DType.BF16: 10.0,
1376 }
1377 # POW highest base value (within a safe margin of error) that can be raised
1378 # to +ve exponent that doesn't become Infinity
1379 TVG_FLOAT_HIGH_VALUE_POW_BASE = {
1380 DType.FP32: math.floor(
1381 math.pow(
1382 TVG_FLOAT_HIGH_VALUE[DType.FP32],
1383 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1384 )
1385 ),
1386 DType.FP16: math.floor(
1387 math.pow(
1388 TVG_FLOAT_HIGH_VALUE[DType.FP16],
1389 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1390 )
1391 ),
1392 DType.BF16: math.floor(
1393 math.pow(
1394 TVG_FLOAT_HIGH_VALUE[DType.BF16],
1395 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1396 )
1397 ),
1398 }
1399 # POW lowest base value (within a safe margin of error) that can be raised
1400 # to -ve exponent that doesn't become Infinity
1401 TVG_FLOAT_LOW_VALUE_POW_BASE = {
1402 DType.FP32: math.ceil(
1403 math.pow(
1404 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
1405 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1406 )
1407 * 1000
1408 )
1409 / 1000,
1410 DType.FP16: math.ceil(
1411 math.pow(
1412 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
1413 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1414 )
1415 * 1000
1416 )
1417 / 1000,
1418 DType.BF16: math.ceil(
1419 math.pow(
1420 1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
1421 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1422 )
1423 * 1000
1424 )
1425 / 1000,
1426 }
1427
1428 @staticmethod
    def tvgPow(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create input values for POW using per-test-set base/exponent ranges.

        Test set 0: positive base with fractional exponent.
        Test set 1: positive base with integer (rounded) exponent.
        Test set 2: negative base with integer (rounded) exponent.
        """
        if error_name is not None:
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
        dtype = dtypeList[0]
        # Different ranges for POW
        test_set = argsDict["s"]
        if test_set == 0:
            # Positive base with fractional exponent
            base_range = TosaTensorValuesGen._get_data_range(
                testGen,
                dtype,
                TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
                TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
            )
            exp_range = TosaTensorValuesGen._get_data_range(
                testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
            )
            exp_round = False
        else:
            # Integer exponent
            exp_range = TosaTensorValuesGen._get_data_range(
                testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
            )
            exp_round = True
            if test_set == 1:
                # Positive base
                base_range = TosaTensorValuesGen._get_data_range(
                    testGen,
                    dtype,
                    TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
                    TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
                )
            else:
                assert test_set == 2
                # Negative base
                # Supply new look up tables with negative values
                base_range = TosaTensorValuesGen._get_data_range(
                    testGen,
                    dtype,
                    {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
                    {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
                )

        # Per-input ranges: input 0 is the base, input 1 is the exponent
        data_range_list = (
            {
                "range": base_range,
            },
            {
                "range": exp_range,
                "round": exp_round,
            },
        )
        argsDict["data_range_list"] = data_range_list
        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, opName, dtypeList, shapeList, argsDict, error_name
        )
1487
1488 @staticmethod
1489 def tvgLogRsqrt(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1490 # LOG & RSQRT data range from lowest expressible positive number to
1491 # largest to avoid NaNs
1492 data_range = TosaTensorValuesGen._get_data_range(
1493 testGen,
1494 dtypeList[0],
1495 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1496 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1497 )
1498 if data_range:
1499 argsDict["data_range"] = data_range
1500
1501 return TosaTensorValuesGen.tvgLazyGenDefault(
1502 testGen, opName, dtypeList, shapeList, argsDict, error_name
1503 )
1504
    # Set the EXP data range to the log of the largest to smallest values
    # to avoid infinities or making the result zero
    TVG_FLOAT_HIGH_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
    # Lowest EXP input per type - exp() of anything smaller would fall below
    # the type's smallest expressible value
    TVG_FLOAT_LOW_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
    }
1517
1518 @staticmethod
1519 def tvgExp(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1520 data_range = TosaTensorValuesGen._get_data_range(
1521 testGen,
1522 dtypeList[0],
1523 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1524 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1525 )
1526 if data_range:
1527 argsDict["data_range"] = data_range
1528
1529 return TosaTensorValuesGen.tvgLazyGenDefault(
1530 testGen, opName, dtypeList, shapeList, argsDict, error_name
1531 )
1532
1533 @staticmethod
    def tvgFullyConnected(
        testGen, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Create input values for FULLY_CONNECTED.

        BF16 non-error, non-compliance tests get a capped data range so a dot
        product over the input channels cannot reach infinity/NaN.
        """
        dtype = dtypeList[0]
        if (
            error_name is None
            and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
            and dtype in (DType.BF16,)
        ):
            # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
            # Limit ranges for (non error & non compliance) FP tests by using
            # values that can be multiplied on any axis to not hit infinity/NaN
            # IC = number of input channels (dimension 1 of the input)
            IC = shapeList[0][1]
            highval_lookup = {
                dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
            }
            data_range = TosaTensorValuesGen._get_data_range(
                testGen, dtype, highval_lookup
            )
            assert data_range is not None
            argsDict["data_range"] = data_range

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, opName, dtypeList, shapeList, argsDict, error_name
        )
1559
Jeremy Johnson708da822023-11-15 16:25:45 +00001560 @staticmethod
1561 def tvgCast(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1562 in_dtype = dtypeList[0]
1563 out_dtype = argsDict["out_type"]
1564 # Create look up to limit input tensor to output type maximums to avoid
1565 # FP infinities and saturation of integers
1566 out_range = testGen.getDTypeRange(out_dtype, high_inclusive=True)
1567 highval_lookup = {in_dtype: out_range[1]}
1568 data_range = TosaTensorValuesGen._get_data_range(
1569 testGen,
1570 in_dtype,
1571 highval_lookup,
1572 )
1573
1574 assert data_range is not None
1575 argsDict["data_range"] = data_range
1576
1577 return TosaTensorValuesGen.tvgLazyGenDefault(
1578 testGen, opName, dtypeList, shapeList, argsDict, error_name
1579 )
1580
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001581 @staticmethod
1582 def tvgGather(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1583 K = shapeList[0][1]
1584
1585 # Fix the type of the indices tensor
1586 dtypeList[1] = DType.INT32
1587
1588 dtype = dtypeList[0]
1589 if not gtu.dtypeIsSupportedByCompliance(dtype):
1590 # Test unsupported by data generator
1591 op = testGen.TOSA_OP_LIST[opName]
1592 pCount, cCount = op["operands"]
1593 assert (
1594 pCount == 2 and cCount == 0
1595 ), "Op.GATHER must have 2 placeholders, 0 consts"
1596
1597 tens_ser_list = []
1598 for idx, shape in enumerate(shapeList):
1599 dtype = dtypeList[idx]
1600 if idx != 1:
1601 arr = testGen.getRandTensor(shape, dtype)
1602 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1603 else:
1604 # Limit data range of indices tensor upto K (exclusive)
1605 arr = testGen.getRandTensor(shape, dtype, (0, K))
1606 # To match old functionality - create indices as CONST
1607 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1608
1609 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1610
1611 else:
1612 # ERROR_IF or floating point test
1613 # Use inclusive values upto index K for indices tensor
1614 data_range_list = (
1615 {"range": None},
1616 {"range": (0, K - 1)},
1617 )
1618 argsDict["data_range_list"] = data_range_list
1619
1620 return TosaTensorValuesGen.tvgLazyGenDefault(
1621 testGen, opName, dtypeList, shapeList, argsDict, error_name
1622 )
1623
    @staticmethod
    def tvgScatter(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create/configure tensor values for SCATTER tests.

        Forces the indices tensor (input 1) to INT32.  For dtypes the data
        generator cannot handle, the tensors are built directly here with the
        indices created as a CONST of unique, per-batch output positions;
        otherwise data ranges are set and generation is delegated to
        tvgLazyGenDefault.
        """
        K = shapeList[0][1]
        W = shapeList[2][1]

        # Work out an indices tensor here with data that doesn't exceed the
        # dimension K of the values_in tensor and does NOT repeat the same K
        # location as needed by the spec:
        # "It is not permitted to repeat the same output index within a single
        # SCATTER operation and so each output index occurs at most once."
        assert K >= W, "Op.SCATTER W must be smaller or equal to K"

        # Fix the type of the indices tensor
        dtypeList[1] = DType.INT32

        dtype = dtypeList[0]
        if not gtu.dtypeIsSupportedByCompliance(dtype):
            # Test unsupported by data generator
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 3 and cCount == 0
            ), "Op.SCATTER must have 3 placeholders, 0 consts"

            tens_ser_list = []
            for idx, shape in enumerate(shapeList):
                dtype = dtypeList[idx]
                if idx != 1:
                    arr = testGen.getRandTensor(shape, dtype)
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    # Create the indices array
                    assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
                    arr = []
                    for n in range(shape[0]):
                        # Get a shuffled list of output indices (0 to K-1) and
                        # limit length to W
                        arr.append(testGen.rng.permutation(K)[:W])
                    indices_arr = np.array(arr, dtype=np.int32)  # (N, W)
                    # To match old functionality - create indices as CONST
                    tens_ser_list.append(
                        testGen.ser.addConst(shape, dtype, indices_arr)
                    )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        else:
            # ERROR_IF or floating point test
            # Use inclusive values upto index K for indices tensor
            data_range_list = (
                {"range": None},
                {"range": (0, K - 1)},
                {"range": None},
            )
            argsDict["data_range_list"] = data_range_list

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
1683
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001684
1685class TosaArgGen:
1686 """Argument generators create exhaustive or random lists of attributes for
1687 operators that take attributes or other parameters.
1688
1689 The return value is a list of (descriptive_name, [arglist]) tuples where
1690 the descriptive_name is appended to the test name and the arglist is expanded
1691 as arguments to the operator build function.
1692 """
1693
    def __init__(self):
        # All generators are static methods - instances hold no state
        pass
1696
1697 @staticmethod
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001698 def _add_data_generators(testGen, opName, dtype, arg_list, error_name):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001699 """Add extra tests for each type of data generator for this op."""
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001700 if (
1701 error_name is None
1702 and "data_gen" in testGen.TOSA_OP_LIST[opName]
1703 and gtu.dtypeIsSupportedByCompliance(dtype)
1704 ):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001705 if dtype in [DType.FP16, DType.FP32, DType.BF16]:
1706 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
1707 else:
1708 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
1709 else:
1710 # Error test or No data generator types listed - assume random
1711 dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
1712
1713 # Expand arg list with other data generator types
1714 new_arg_list = []
1715 for dg_type in dataGenTypesList:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001716 for arg_str, args_dict in arg_list:
1717 args_dict["dg_type"] = dg_type
Jeremy Johnson1271c442023-09-05 11:39:26 +01001718 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
Jeremy Johnson30476252023-11-20 16:15:30 +00001719 if error_name is None:
1720 num_test_sets = (
1721 args_dict["num_test_sets"]
1722 if "num_test_sets" in args_dict
1723 else 0
1724 )
1725 else:
1726 num_test_sets = 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01001727
1728 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
1729 # Extra tests for each dot product test set
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001730 dot_products = args_dict["dot_products"]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001731 if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001732 shape_info = (
1733 " ({})".format(testGen.shapeStr(args_dict["shape"]))
1734 if "shape" in args_dict
1735 else ""
1736 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01001737 print(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001738 f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
Jeremy Johnson1271c442023-09-05 11:39:26 +01001739 )
1740 continue
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001741 # KS and acc_type is required by all dot product generators
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001742 assert "ks" in args_dict
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001743 assert "acc_type" in args_dict
Jeremy Johnson1271c442023-09-05 11:39:26 +01001744
Jeremy Johnson30476252023-11-20 16:15:30 +00001745 num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS
1746
1747 if num_test_sets > 0:
1748 for s in range(0, num_test_sets):
1749 new_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001750 new_args_dict = args_dict.copy()
1751 new_args_dict["s"] = s
1752 new_arg_list.append((new_arg_str, new_args_dict))
Jeremy Johnson30476252023-11-20 16:15:30 +00001753 else:
1754 # Default is a single test
1755 new_arg_list.append((arg_str, args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001756
1757 return new_arg_list
1758
1759 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001760 def agNone(testGen, opName, shapeList, dtype, error_name=None):
1761 """A trivial argument generator for operators that don't take any
1762 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001763 arg_list = TosaArgGen._add_data_generators(
1764 testGen,
1765 opName,
1766 dtype,
1767 [("", {})],
1768 error_name,
1769 )
1770 # Return list of tuples: (arg_str, args_dict)
1771 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001772
1773 @staticmethod
Jeremy Johnson30476252023-11-20 16:15:30 +00001774 def agPow(testGen, opName, shapeList, dtype, error_name=None):
1775 """Pow operator needs different test sets to cover random numbers
1776 without creating NaNs or Infs"""
1777 arg_list = TosaArgGen._add_data_generators(
1778 testGen,
1779 opName,
1780 dtype,
1781 [("", {"num_test_sets": 3})],
1782 error_name,
1783 )
1784 # Return list of tuples: (arg_str, args_dict)
1785 return arg_list
1786
1787 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001788 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
1789 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001790 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001791 shape = shapeList[0]
1792
1793 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001794 # Set too small axis
1795 axes = [testGen.rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001796 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001797 # Set too large axis
1798 axes = [testGen.rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001799 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001800 # Create tests for each dimension
1801 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001802
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001803 opid = testGen.TOSA_OP_LIST[opName]["op"]
1804
1805 for a in axes:
1806 args_dict = {"axis": int(a)}
1807 if opid == Op.REDUCE_SUM:
1808 args_dict["dot_products"] = gtu.product(shape)
1809 args_dict["shape"] = shape
1810 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1811 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1812
1813 arg_list.append(("axis{}".format(a), args_dict))
1814
1815 arg_list = TosaArgGen._add_data_generators(
1816 testGen,
1817 opName,
1818 dtype,
1819 arg_list,
1820 error_name,
1821 )
1822 # Return list of tuples: (arg_str, args_dict)
1823 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001824
1825 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001826 def _calculate_sparsity(num_tests, sparsity_factor):
1827 sparsity = num_tests // sparsity_factor + 1
1828 # If there are only a small number of tests, just select them all
1829 if sparsity < 13:
1830 sparsity = 1
1831 # To get a variety of parameter combinations sparsity should not be a
1832 # multiple of 2, 3 or 5
1833 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1834 sparsity += 1
1835 return sparsity
1836
    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        """Generate stride/pad/dilation argument combinations for convolutions.

        Used by CONV2D, CONV3D and DEPTHWISE_CONV2D.  Returns a list of
        (arg_str, args_dict) tuples where args_dict holds the attribute values
        plus the compliance info (ks, dot_products, shape) consumed by the
        data generators.  ERROR_IF runs substitute specific invalid values;
        level8k runs only exercise the 8k-level boundary values.
        """
        # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        # Shape: Batches, (Depth), Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: (OFM channels), (KD), KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Op type checks
        conv3d = opName.startswith("conv3d")
        depthwise = opName.startswith("depthwise")

        # Check the rank
        rank = 5 if conv3d else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits channels
        k_rank = rank - 2
        # Depthwise filters are (KH, KW, C, M) so the kernel starts at index 0
        k_pos = 0 if depthwise else 1
        k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
        # compliance size - KS
        k_size = gtu.product(k_shape)
        if not depthwise:
            k_size *= ifm_shape[-1]

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            if error_name == ErrorIf.PadSmallerZero:
                p_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                # Stride must be greater than 1 to force non-integer error
                startStride = (
                    1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
                )
                s_vals = [
                    x for x in range(startStride, testGen.args.max_conv_stride + 1)
                ]
            strides = {x for x in itertools.product(*([s_vals] * k_rank))}
            if error_name == ErrorIf.DilationSmallerOne:
                d_vals = [testGen.rng.choice(range(-5, 1))]
            else:
                d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
            dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[0, bigPadding]] * (k_rank * 2))
                            )
                        }
                    )
                bigStride = 8
                strides.update(
                    {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
                )
                bigDilation = 7
                dilations.update(
                    {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
                )
            max_dim_size = None

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(paddings) * len(strides) * len(dilations), sparsity_factor
            )
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            dilation_shape = [1] * k_rank
            pad_shape = [0] * k_rank * 2
            if conv3d:
                # Small stride apart from for big kernel (see below) to keep
                # tensor size/calculation small
                stride_shape = [1] * k_rank
                for idx in range(k_rank):
                    pad_offset = idx * 2
                    if k_shape[idx] == bigKernel:
                        # Padding shape needs to account for tensor shape
                        pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                        pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                        # Big stride to reduce output size
                        stride_shape[idx] = bigKernel
                    else:
                        # Account for kernel size
                        pad_shape[pad_offset] = k_shape[idx] - 1
            else:
                # Always have a large stride with extra padding and dilation to keep
                # tensor calculation reasonable
                stride_shape = [bigKernel] * k_rank
                for idx in range(k_rank):
                    # Dilation shape must account for kernel size
                    dilation_shape[idx] = bigKernel // k_shape[idx]
                    # Padding shape needs to accommodate tensor/kernel & dilation
                    pad_offset = idx * 2
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

            strides = {tuple(stride_shape)}
            dilations = {tuple(dilation_shape)}
            paddings = {tuple(pad_shape)}
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a positive
                        # sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
                        and (
                            k_rank < 3
                            or (
                                (ifm_shape[3] - 1 + p[4] + p[5])
                                > d[2] * (k_shape[2] - 1)
                            )
                        )
                    ):
                        remainders = []
                        outputs = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            partial = (
                                ifm_shape[index + 1]
                                - 1
                                + p[pad_offset]
                                + p[pad_offset + 1]
                                - (k_shape[index] - 1) * d[index]
                            )
                            remainders.append(partial % s[index])
                            outputs.append((partial // s[index]) + 1)

                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            if (
                                max_dim_size is not None
                                and max(outputs) >= max_dim_size
                            ):
                                # Test will consume too much memory - skip it
                                continue

                            # Compliance - number of dot product calculations
                            if depthwise:
                                # N*OH*OW*C*M
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, *filter_shape[2:])
                                )
                            else:
                                # N*OH*OW*OC or N*OD*OH*OW*OC
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, filter_shape[0])
                                )
                            args_dict = {
                                "acc_type": accum_dtype,
                                "stride": s,
                                "pad": p,
                                "dilation": d,
                                "kernel": k_shape,
                                "ks": k_size,
                                "dot_products": dots,
                                "shape": ifm_shape,
                            }

                            # Support for larger values than 9 needs different delimiter
                            delim = "" if max(s + p + d) <= 9 else "x"
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        delim.join([str(x) for x in s]),
                                        delim.join([str(x) for x in p]),
                                        delim.join([str(x) for x in d]),
                                    ),
                                    args_dict,
                                )
                            )
                    # Count every combination so the sparsity sampling
                    # advances deterministically
                    n += 1

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtypes[0],
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2063
2064 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01002065 def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
2066
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002067 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002068 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01002069
2070 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson1271c442023-09-05 11:39:26 +01002071 accum_dtype = gtu.get_wrong_output_type(opName, testGen.rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01002072 elif error_name == ErrorIf.WrongInputType:
2073 # Pick some potentially correct output dtype if input type is incorrect
2074 accum_dtype = DType.INT32
2075 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +01002076 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01002077
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002078 # Set up compliance info
2079 args_dict = {
2080 "acc_type": accum_dtype,
2081 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
2082 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
2083 "shape": shapeList[0],
2084 }
2085
2086 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
2087
2088 arg_list = TosaArgGen._add_data_generators(
2089 testGen,
2090 opName,
2091 input_dtype,
2092 arg_list,
2093 error_name,
2094 )
2095 # Return list of tuples: (arg_str, args_dict)
2096 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002097
2098 @staticmethod
2099 def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
2100 # Get valid accumulate type(s)
2101 if dtype == DType.INT8:
2102 accum_dtypes = [DType.INT32]
2103 elif dtype == DType.INT16:
2104 accum_dtypes = [DType.INT48]
2105 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002106 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002107 elif dtype == DType.BF16:
2108 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002109 elif dtype == DType.FP32:
2110 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01002111 elif error_name is None:
2112 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
2113
2114 if error_name == ErrorIf.WrongOutputType:
2115 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson1271c442023-09-05 11:39:26 +01002116 accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01002117 elif error_name == ErrorIf.WrongInputType:
2118 # Pick some potentially correct output dtype if input type is incorrect
2119 accum_dtypes = [DType.INT32]
2120
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002121 # Set up compliance info
2122 args_dict = {
2123 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
2124 # Set dot_products = N*H*W
2125 "dot_products": gtu.product(
2126 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
2127 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002128 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002129 }
2130
2131 # Create arg tuple of string and dict
2132 arg_list = []
2133 for a in accum_dtypes:
2134 d = args_dict.copy()
2135 d["acc_type"] = a
2136 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01002137
2138 arg_list = TosaArgGen._add_data_generators(
2139 testGen,
2140 opName,
2141 dtype,
2142 arg_list,
2143 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002144 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002145 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01002146 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002147
2148 @staticmethod
2149 def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002150 arg_list = []
2151
Jeremy Johnson0c716862023-04-13 17:18:19 +01002152 if testGen.args.level8k and error_name is not None:
2153 # Don't produce negative large tests
2154 return arg_list
2155
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002156 ifm_shape = shapeList[0]
2157 filter_shape = shapeList[1]
2158
Jeremy Johnson1271c442023-09-05 11:39:26 +01002159 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01002160
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002161 # Must be rank 4
2162 if error_name != ErrorIf.WrongRank:
2163 assert len(ifm_shape) == 4
2164 assert len(filter_shape) == 4
2165
Jeremy Johnson0c716862023-04-13 17:18:19 +01002166 k_shape = tuple(filter_shape[1:3])
Jeremy Johnson95a67102024-01-10 14:16:39 +00002167 # compliance size - KS
2168 k_size = gtu.product((*k_shape, ifm_shape[3]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002169
Jeremy Johnson0c716862023-04-13 17:18:19 +01002170 if not testGen.args.level8k:
2171 # Generate comprehensive argument lists
2172 # - except for named errors, which use specific invalid value(s)
2173 smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
2174 if error_name == ErrorIf.PadLargerEqualKernel:
2175 max_filter_size = -max(k_shape[0], k_shape[1])
2176 p_vals = [
2177 testGen.rng.choice(range(max_filter_size - 10, max_filter_size))
2178 ]
2179 else:
2180 p_vals = [
2181 x
2182 for x in range(
2183 smallest_padding_size, testGen.args.max_conv_padding + 1
2184 )
2185 ]
2186 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2187 if error_name == ErrorIf.StrideSmallerOne:
2188 # Can't use stride=0, as it is used to derive output shape, as a divisor
2189 s_vals = [testGen.rng.choice(range(-5, 0))]
2190 else:
2191 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
2192 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002193
Jeremy Johnson0c716862023-04-13 17:18:19 +01002194 if not error_name and testGen.args.oversize:
2195 # add some oversize argument values
2196 if max(ifm_shape) < 64:
2197 bigPadding = 9
2198 paddings.update(
2199 {
2200 x
2201 for x in itertools.product(
2202 *([[smallest_padding_size, bigPadding]] * 4)
2203 )
2204 }
2205 )
2206 bigStride = 8
2207 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
2208
2209 # There are too many parameter combinations, so generate them sparsely,
2210 # very sparse for negative tests
2211 sparsity_factor = 2 if error_name else 10
2212 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
2213 # If there are only a small number of tests, just select them all
2214 if sparsity < 13:
2215 sparsity = 1
2216 # To get a variety of parameter combinations sparsity should not be a
2217 # multiple of 2, 3 or 5
2218 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
2219 sparsity += 1
2220 else:
2221 # Only test 8k levels boundaries
2222 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2223 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2224 bigPadding = bigKernel
2225
2226 pad_shape = [0] * (len(k_shape) * 2)
2227 stride_shape = [1] * len(k_shape)
2228 # The point at which input dimension combined with the stride will
2229 # create large output sizes!
2230 LARGE_SIZE = 2
2231 for idx in range(len(k_shape)):
2232 pad_offset = idx * 2
2233 if k_shape[idx] == bigKernel:
2234 # Set large stride
2235 stride_shape[idx] = bigKernel
2236 # Use negative output padding to reduce shape size
2237 pad_shape[pad_offset] = -(bigPadding - 1)
2238 if ifm_shape[idx + 1] > LARGE_SIZE:
2239 pad_shape[pad_offset + 1] = -(bigPadding - 1)
2240 else:
2241 # The other dimension should be the bigKernel
2242 alt_idx = 1 - idx
2243 if (
2244 k_shape[alt_idx] == bigKernel
2245 and ifm_shape[alt_idx + 1] < LARGE_SIZE
2246 ):
2247 # As the input is small, the large stride won't
2248 # affect the output so we can add some padding
2249 pad_shape[pad_offset + 1] = bigPadding
2250
2251 strides = {tuple(stride_shape)}
2252 paddings = {tuple(pad_shape)}
2253
2254 # Currently allow all combinations that are reasonable size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002255 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002256
2257 n = 0
2258 for s in sorted(list(strides)):
2259 for p in sorted(list(paddings)):
TatWai Chong24594f52022-06-08 00:48:04 -07002260 if n % sparsity == 0:
2261 # Determine the output shape
Jeremy Johnson0c716862023-04-13 17:18:19 +01002262 oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
2263 ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
TatWai Chong24594f52022-06-08 00:48:04 -07002264 os = [ifm_shape[0], oh, ow, filter_shape[0]]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002265
Jeremy Johnson95a67102024-01-10 14:16:39 +00002266 # N*OH*OW*OC
2267 dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
2268 args_dict = {
2269 "acc_type": accum_dtype,
2270 "stride": s,
2271 "pad": p,
2272 "kernel": k_shape,
2273 "ks": k_size,
2274 "dot_products": dots,
2275 "shape": ifm_shape,
2276 "out_shape": os,
2277 }
2278
Jeremy Johnson0c716862023-04-13 17:18:19 +01002279 # Support for larger values than 9 needs different delimiter
2280 delim = "" if max(s + p) <= 9 else "x"
TatWai Chong24594f52022-06-08 00:48:04 -07002281 arg_list.append(
2282 (
James Ward8b390432022-08-12 20:48:56 +01002283 "acc{}_st{}_pad{}_os{}".format(
2284 testGen.typeStr(accum_dtype),
Jeremy Johnson0c716862023-04-13 17:18:19 +01002285 delim.join([str(x) for x in s]),
2286 delim.join([str(x) for x in p]),
TatWai Chong24594f52022-06-08 00:48:04 -07002287 "x".join([str(x) for x in os]),
2288 ),
Jeremy Johnson95a67102024-01-10 14:16:39 +00002289 args_dict,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002290 )
TatWai Chong24594f52022-06-08 00:48:04 -07002291 )
2292 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002293
Jeremy Johnson95a67102024-01-10 14:16:39 +00002294 arg_list = TosaArgGen._add_data_generators(
2295 testGen,
2296 opName,
2297 dtypes[0],
2298 arg_list,
2299 error_name,
2300 )
2301 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002302 return arg_list
2303
    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        """Generate padding argument combinations for PAD.

        Returns a list of (arg_str, args_dict) tuples where args_dict holds
        the per-dimension before/after padding array and the pad constants.
        PadSmallerZero ERROR_IF runs use negative padding values instead,
        adjusted so output dimensions stay positive.
        """
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        # Only one of the pad constants is used, depending on the dtype
        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            # Unsupported dtype - no tests
            return []

        list_shape_pad_values = list(shape_pad_values)
        # If we are producing tests for rank 6 or greater use sparsity
        if len(list_shape_pad_values) > 1024:
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(list_shape_pad_values), sparsity_factor
            )
        else:
            sparsity = 1

        # Build arg list
        arg_list = []
        for n, paddings in enumerate(list_shape_pad_values):
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                    # Skip the test if a dimension ends up with no negative
                    # padding left to exercise the error
                    if all([p > -1 for p in paddings[i]]):
                        args_valid = False
            if args_valid and n % sparsity == 0:
                # Name encodes the before/after padding of every dimension
                name = "pad"
                for r in range(rank):
                    before, after = paddings[r]
                    name = f"{name}{before}{after}"
                args_dict = {
                    "pad": np.array(paddings),
                    "pad_const_int": pad_const_int,
                    "pad_const_fp": pad_const_fp,
                }
                arg_list.append((name, args_dict))

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2378
2379 @staticmethod
2380 def agPooling(testGen, opName, shapeList, dtype, error_name=None):
2381 arg_list = []
2382
2383 shape = shapeList[0]
2384 if error_name != ErrorIf.WrongRank:
2385 assert len(shape) == 4
2386
Jeremy Johnson0c716862023-04-13 17:18:19 +01002387 test_level8k = testGen.args.level8k and error_name is None
2388
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002389 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002390 startKernel = 2
2391 startPad = 0
2392 if not test_level8k:
2393 # Generate comprehensive argument lists
2394 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2395 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2396 # Stride must be greater than 1 to force non-integer error
2397 s_vals = [
2398 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2399 ]
2400 strides = {x for x in itertools.product(*([s_vals] * 2))}
2401 k_vals = [
2402 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2403 ]
2404 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2405 max_dim_size = None
2406 else:
2407 # Only test 8k levels
2408 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2409 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2410 strides = {(1, bigStride), (bigStride, 4)}
2411 kernels = {(1, bigKernel), (bigKernel, 3)}
2412 paddings = set()
2413 for s in sorted(list(strides)):
2414 for k in sorted(list(kernels)):
2415 padding = []
2416 for idx in range(len(k)):
2417 total_padding = s[idx] - shape[idx + 1] + k[idx]
2418 while total_padding < 0:
2419 # Must meet: shape + padding > kernel
2420 total_padding += s[idx]
2421 if total_padding < k[idx]:
2422 padding.extend([0, total_padding])
2423 else:
2424 # Note this may produce padding >= k[idx] which is not
2425 # allowed - but will be ignored in the creation loop below
2426 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2427 paddings.add(tuple(padding))
2428 # Create a limit for the output dimensions size
2429 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002430
James Ward8b390432022-08-12 20:48:56 +01002431 if opName == "max_pool2d":
2432 accum_dtypes = [None] # max_pool has no accumulate dtype
2433 elif dtype == DType.INT8 or dtype == DType.INT16:
2434 accum_dtypes = [DType.INT32]
2435 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002436 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002437 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002438 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01002439 elif error_name is None:
2440 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2441 else:
2442 # Set to something for the ErrorIf case which has
2443 # incorrect input data-type
2444 accum_dtypes = [DType.INT32]
2445
Jeremy Johnson0c716862023-04-13 17:18:19 +01002446 if not test_level8k:
2447 if testGen.args.oversize:
2448 # add some oversize argument values
2449 bigStride = 7
2450 bigKernel = 9
2451 strides.update(
2452 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002453 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002454 kernels.update(
2455 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2456 )
2457 if max(shape) < 64:
2458 # padding must be less than the kernel size
2459 bigPadding = bigKernel - 1
2460 paddings.update(
2461 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2462 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002463
Jeremy Johnson0c716862023-04-13 17:18:19 +01002464 # There are too many parameter combinations, so generate them sparsely,
2465 # very sparse for negative tests
2466 sparsity_factor = 2 if error_name else 500
2467 sparsity = (
2468 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2469 )
2470 else:
2471 # We have already limited test output combinations for 8k tests
2472 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002473
James Ward8b390432022-08-12 20:48:56 +01002474 arg_str = (
2475 "acc{}_st{}_kern{}_pad{}"
2476 if accum_dtypes[0] is not None
2477 else "st{}_kern{}_pad{}"
2478 )
2479
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002480 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01002481 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002482 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01002483
2484 # Support for larger values than 9 needs different delimiter
2485 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01002486 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01002487 delim.join([str(x) for x in stride]),
2488 delim.join([str(x) for x in kern]),
2489 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01002490 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002491 args_dict = {
2492 "stride": stride,
2493 "pad": pad,
2494 "kernel": kern,
2495 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002496 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002497 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
2498 }
James Ward8b390432022-08-12 20:48:56 +01002499
2500 if accum is not None:
2501 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002502 args_dict["acc_type"] = accum
2503 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002504
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002505 n = 0
James Ward8b390432022-08-12 20:48:56 +01002506 for a in accum_dtypes:
2507 for s in sorted(list(strides)):
2508 for p in sorted(list(paddings)):
2509 for k in sorted(list(kernels)):
2510 if error_name in [
2511 ErrorIf.StrideSmallerOne,
2512 ErrorIf.KernelSmallerOne,
2513 ErrorIf.PadSmallerZero,
2514 ErrorIf.PadLargerEqualKernel,
2515 ]:
2516 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
2517 testGen, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002518 )
James Ward8b390432022-08-12 20:48:56 +01002519 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002520 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002521 get_arg_list_element(a, sNew, pNew, kNew, shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002522 )
James Ward8b390432022-08-12 20:48:56 +01002523 elif (
2524 n % sparsity == 0
2525 # padding must not exceed the kernel size
2526 and p[0] < k[0]
2527 and p[1] < k[0]
2528 and p[2] < k[1]
2529 and p[3] < k[1]
2530 # the padded shape must exceed the kernel size
2531 and (shape[1] + p[0] + p[1]) > k[0]
2532 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002533 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002534 partial_h = shape[1] + p[0] + p[1] - k[0]
2535 partial_w = shape[2] + p[2] + p[3] - k[1]
2536 remainder_h = partial_h % s[0]
2537 remainder_w = partial_w % s[1]
2538 output_h = partial_h // s[0] + 1
2539 output_w = partial_w // s[1] + 1
2540 # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
James Ward8b390432022-08-12 20:48:56 +01002541 if (
2542 # the parameters must produce integer exact output
2543 error_name != ErrorIf.PoolingOutputShapeNonInteger
2544 and remainder_h == 0
2545 and remainder_w == 0
2546 ) or (
2547 error_name == ErrorIf.PoolingOutputShapeNonInteger
2548 and (remainder_h != 0 or remainder_w != 0)
2549 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002550 if (
2551 max_dim_size is not None
2552 and max(output_h, output_w) > max_dim_size
2553 ):
2554 # Test will consume too much memory - skip it
2555 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002556 # Dot products = N*OH*OW*C
2557 dp = gtu.product(
2558 (shape[0], output_h, output_w, shape[3])
2559 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002560 arg_list.append(
2561 get_arg_list_element(a, s, p, k, dp, shape)
2562 )
James Ward8b390432022-08-12 20:48:56 +01002563 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002564
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002565 # Now add data generator types
2566 arg_list = TosaArgGen._add_data_generators(
2567 testGen,
2568 opName,
2569 dtype,
2570 arg_list,
2571 error_name,
2572 )
2573
2574 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002575 return arg_list
2576
2577 @staticmethod
2578 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
2579 arg_list = []
2580
2581 # Enumerate the output types here
2582 if error_name == ErrorIf.WrongOutputType:
2583 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
2584 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00002585 dtypeList = [
2586 DType.BOOL,
2587 DType.INT16,
2588 DType.INT32,
2589 DType.FP16,
2590 DType.BF16,
2591 DType.FP32,
2592 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002593 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00002594 dtypeList = [
2595 DType.BOOL,
2596 DType.INT8,
2597 DType.INT32,
2598 DType.FP16,
2599 DType.BF16,
2600 DType.FP32,
2601 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002602 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00002603 dtypeList = [
2604 DType.BOOL,
2605 DType.INT8,
2606 DType.INT16,
2607 DType.FP16,
2608 DType.BF16,
2609 DType.FP32,
2610 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002611 elif inDtype == DType.BOOL:
2612 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01002613 elif inDtype == DType.FP16:
James Ward736fd1a2023-01-23 17:13:37 +00002614 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002615 elif inDtype == DType.BF16:
James Ward736fd1a2023-01-23 17:13:37 +00002616 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002617 elif inDtype == DType.FP32:
James Ward736fd1a2023-01-23 17:13:37 +00002618 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002619 elif error_name == ErrorIf.WrongInputType:
2620 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002621 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002622 else:
2623 raise Exception("Unexpected input dtype: {}".format(inDtype))
2624
2625 for dtype in dtypeList:
Jeremy Johnson708da822023-11-15 16:25:45 +00002626 arg_list.append(
2627 ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
2628 )
2629
2630 # Now add data generator types
2631 arg_list = TosaArgGen._add_data_generators(
2632 testGen,
2633 opName,
2634 dtype,
2635 arg_list,
2636 error_name,
2637 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002638
2639 return arg_list
2640
    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        """Generate RESCALE test arguments.

        Enumerates output dtype x scale32 x double_round x per_channel
        combinations, skipping those that are invalid for the given input
        dtype (or deliberately keeping invalid ones for the requested
        ErrorIf case).  Returns a list of (arg_str, arg_values_list) tuples.
        """
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            # Unsigned or INT8 outputs must have a zero point of zero, so they
            # cannot be used for the OutputZeroPointNotZero error test
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                # Not a suitable wrong-output-type combination for this input
                continue

            # Combine the remaining scale/rounding/channel options, skipping
            # combinations that are illegal unless the ErrorIf test needs them
            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition. Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [outDtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list
2739
2740 @staticmethod
2741 def agMul(testGen, opName, shapeList, dtype, error_name=None):
2742 arg_list = []
2743
2744 if dtype is DType.INT32:
2745 for p in range(testGen.args.num_rand_permutations):
2746
2747 shift = testGen.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002748 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002749 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002750 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002751
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002752 arg_list = TosaArgGen._add_data_generators(
2753 testGen,
2754 opName,
2755 dtype,
2756 arg_list,
2757 error_name,
2758 )
2759 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002760 return arg_list
2761
2762 @staticmethod
2763 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
2764 arg_list = []
2765
2766 arg_list.append(("roundTrue", [True]))
2767 arg_list.append(("roundFalse", [False]))
2768
2769 return arg_list
2770
Luke Hutton57287132023-02-06 14:54:18 +00002771 @staticmethod
2772 def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
2773 arg_list = []
2774
2775 arg_list.append(("inverseTrue", [True]))
2776 arg_list.append(("inverseFalse", [False]))
2777
2778 return arg_list
2779
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002780 # Helper function for reshape. Gets some factors of a larger number.
2781 @staticmethod
2782 def getFactors(val, start=1):
2783 factors = []
2784
2785 for i in range(start, int(np.sqrt(val)) + 1):
2786 if (val % i) == 0:
2787 factors.append(i)
2788
2789 return factors
2790
2791 @staticmethod
2792 def agReshape(testGen, opName, shapeList, dtype, error_name=None):
2793 arg_list = []
2794
2795 origShape = shapeList[0]
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002796 totalElements = gtu.product(origShape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002797 factors = TosaArgGen.getFactors(totalElements)
2798
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002799 # Find new shapes up to the number of permutations asked for
2800 # This code is NOT fast. Fortunately, the numbers are fairly small.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002801 for p in range(testGen.args.num_rand_permutations):
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002802 # Rank from 1 to TOSA_TENSOR_MAX_RANK
2803 newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002804 if len(factors) < newRank:
2805 continue
2806
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002807 # escape_counter limits the generation of new shapes to a reasonable time
2808 for escape_counter in range(100):
2809
2810 # Generate the new shape of the chosen new rank
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002811 newShape = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002812 remainingElements = totalElements
2813 shuffledFactors = testGen.rng.permutation(factors)
2814 for i in range(1, newRank):
2815 # pick rank-1 factors
2816 newShape.append(shuffledFactors[0])
2817 remainingElements = remainingElements // shuffledFactors[0]
2818 shuffledFactors = testGen.rng.permutation(
2819 TosaArgGen.getFactors(remainingElements)
2820 )
2821 newShape.append(remainingElements)
2822
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002823 # Check for duplicates
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002824 duplicate = False
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00002825 for name, args_dict in arg_list:
2826 if args_dict["new_shape"] == newShape:
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002827 duplicate = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002828 break
2829
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002830 if not duplicate:
2831 outShape = "x".join([str(x) for x in newShape])
2832 arg_list.append(
2833 (
2834 "perm{}_rank{}_out{}".format(p, newRank, outShape),
2835 {"new_shape": newShape},
2836 )
2837 )
2838 # Found an output shape for this permutation
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002839 break
2840
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00002841 # Now add data generator types
2842 arg_list = TosaArgGen._add_data_generators(
2843 testGen,
2844 opName,
2845 dtype,
2846 arg_list,
2847 error_name,
2848 )
2849
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002850 return arg_list
2851
2852 @staticmethod
2853 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
2854 arg_list = []
2855
2856 ifm_shape = shapeList[0]
2857
2858 if error_name == ErrorIf.IndexOutsideBounds:
2859 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
2860 incorrect_small_index = range(-len(ifm_shape), 0)
2861 permutations = [p for p in itertools.permutations(incorrect_large_index)]
2862 permutations.extend(
2863 [p for p in itertools.permutations(incorrect_small_index)]
2864 )
2865 elif error_name == ErrorIf.IndexUsedTwice:
2866 # Create list with a duplicated index
2867 perm_range = list(range(len(ifm_shape)))
2868 index_choice = testGen.rng.choice(range(len(perm_range)))
2869 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
2870 permutations = [p for p in itertools.permutations(perm_range)]
2871
2872 else:
2873 # Get all permutations
2874 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
2875
2876 # Limit to possible permutations from shape dimension or argument setting
2877 limit = min(len(permutations), testGen.args.num_rand_permutations)
2878
2879 # Get random permutation generator that uses all permutations
2880 random_permutations = testGen.rng.permutation(permutations)
2881
2882 # Create list of required amount of permutations
2883 arg_list = [
2884 ("perm{}".format(p), [random_permutations[p].tolist()])
2885 for p in range(limit)
2886 ]
2887 return arg_list
2888
2889 @staticmethod
2890 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
2891 arg_list = []
2892
2893 ifm_shape = shapeList[0]
2894 rank = len(ifm_shape)
2895
2896 for p in range(testGen.args.num_rand_permutations):
2897 start = []
2898 size = []
2899
2900 valid = True
2901
2902 for i in range(rank):
2903 if ifm_shape[i] > 1:
2904 start.append(testGen.randInt(0, ifm_shape[i]))
2905 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
2906
2907 # Invalid slice size?
2908 if size[i] == 0:
2909 valid = False
2910 else:
2911 start.append(0)
2912 size.append(1)
2913
2914 if valid:
2915 # If ERROR_IF test required then incorrect start, size will be returned
2916 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
2917 testGen, error_name, ifm_shape, start, size
2918 )
evacha017f7d4252024-01-24 12:08:09 +00002919 arg_list.append(("perm{}".format(p), {"start": start, "size": size}))
2920 # Now add data generator types
2921 arg_list = TosaArgGen._add_data_generators(
2922 testGen,
2923 opName,
2924 dtype,
2925 arg_list,
2926 error_name,
2927 )
2928 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002929 return arg_list
2930
2931 @staticmethod
2932 def agTile(testGen, opName, shapeList, dtype, error_name=None):
2933 arg_list = []
2934
2935 ifm_shape = shapeList[0]
2936 rank = len(ifm_shape)
2937
2938 for p in range(testGen.args.num_rand_permutations):
2939
2940 # Pick a few random, but small multiple values
2941 # because otherwise this has a tendency to generate
2942 # enormous tensors
2943 multiples = []
2944 for i in range(rank):
2945 if ifm_shape[i] > 1000:
2946 # Multiple of 1 if ifm_shape dimension is large to reduce
2947 # tensor size
2948 multiples.append(1)
2949 elif max(ifm_shape) > 1000:
2950 multiples.append(2)
2951 else:
2952 multiples.append(testGen.randInt(1, 4))
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00002953 arg_list.append(("perm{}".format(p), {"multiples": multiples}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002954
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00002955 # Now add data generator types
2956 arg_list = TosaArgGen._add_data_generators(
2957 testGen,
2958 opName,
2959 dtype,
2960 arg_list,
2961 error_name,
2962 )
2963 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002964 return arg_list
2965
2966 @staticmethod
2967 def agResize(testGen, opName, shapeList, dtype, error_name=None):
2968 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002969 ifm_shape = shapeList[0]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002970
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002971 def get_aspect_ratio_resize_params():
2972 common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
2973 aspect_ratio = testGen.rng.choice(common_aspect_ratios)
2974 invert = testGen.rng.choice((False, True))
2975 letterbox = testGen.rng.choice((False, True))
2976
2977 scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
2978 scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
2979 scale_y_d = scale_x_d = 1
2980 offset_x = offset_y = 0
2981
2982 if letterbox:
2983 max_border = scale_y_n
2984 border_y = testGen.randInt(low=0, high=max_border)
2985 border_x = 0
2986 else:
2987 # Pillarboxing
2988 border_y = 0
2989 max_border = scale_x_n
2990 border_x = testGen.randInt(low=0, high=max_border)
2991
2992 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
2993 offset = (offset_y, offset_x)
2994 border = (border_y, border_x)
2995
2996 return scale, offset, border
2997
2998 def get_upscale_downscale_params():
2999 valid_params = False
3000 while not valid_params:
3001 upscale = testGen.rng.choice((False, True))
3002
3003 # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
3004 origin_sampling = testGen.rng.choice((False, True))
3005
3006 if upscale:
3007 shift = testGen.randInt(low=1, high=4)
3008 scale_x_d = scale_y_d = 1
3009 scale_x_n = scale_y_n = (
3010 1 << shift if origin_sampling else 2 << shift
3011 )
3012 border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
3013 offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
3014 else:
3015 scale_x_n = 1
3016 scale_y_n = 1
3017
3018 # Return list of valid scale_*_d values (max value 4) given input dim shape
3019 def get_valid_denom(ifm_dim):
3020 return [x for x in range(1, 5) if ifm_dim % x == 1]
3021
3022 # Generate list of valid downscale values and choose one randomly
3023 valid_scale_y_ds = get_valid_denom(ifm_shape[1])
3024 valid_scale_x_ds = get_valid_denom(ifm_shape[2])
3025
3026 if not valid_scale_y_ds and not valid_scale_x_ds:
3027 # Bad parameters, skip
3028 continue
3029
3030 if not valid_scale_y_ds:
3031 scale_y_d = 1
3032 else:
3033 scale_y_d = testGen.rng.choice(valid_scale_y_ds)
3034
3035 if not valid_scale_x_ds:
3036 scale_x_d = 1
3037 else:
3038 scale_x_d = testGen.rng.choice(valid_scale_x_ds)
3039
3040 border_x = border_y = 0
3041 offset_y = testGen.randInt(0, 16 * scale_y_n)
3042 offset_x = testGen.randInt(0, 16 * scale_x_n)
3043 valid_params = True
3044
3045 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3046 offset = (offset_y, offset_x)
3047 border = (border_y, border_x)
3048 return scale, offset, border
3049
3050 def get_rand_params():
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003051 def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
3052 scale = scale_n / scale_d
3053 if scale > max_scale:
3054 factor = scale / max_scale
3055 new_scale_d = math.ceil(scale_d * factor)
3056 assert scale_n / new_scale_d <= max_scale
3057 scale_d = new_scale_d
3058 return scale_d
3059
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003060 # Scale
3061 scale_y_n = testGen.randInt(low=1, high=(1 << 11))
3062 scale_x_n = testGen.randInt(low=1, high=(1 << 11))
3063
3064 scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
3065 scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))
3066
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003067 scale_y_d = fix_scale_to_max_scale(
3068 scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3069 )
3070 scale_x_d = fix_scale_to_max_scale(
3071 scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3072 )
3073
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003074 # Offsets and border within the scale
3075 offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
3076 offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
3077 border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
3078 border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)
3079
3080 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3081 offset = (offset_y, offset_x)
3082 border = (border_y, border_x)
3083 return scale, offset, border
3084
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003085 def get_level_8k_params():
3086 # Create 64x scale - 64/1 to 2048/32
3087 scale_d = testGen.randInt(
3088 low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
3089 )
3090 scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
3091 # Create half to fifth scaling
3092 scale_d_alt = testGen.randInt(low=2, high=6)
3093 scale_n_alt = 1
3094 switch = testGen.rng.choice((False, True))
3095 if switch:
3096 scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
3097 else:
3098 scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)
3099
3100 offset_y = testGen.rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
3101 offset_x = testGen.rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
3102 offset = (offset_y, offset_x)
3103 border_y = testGen.rng.choice((-16 * scale[0], 0, scale[0] - 1))
3104 border_x = testGen.rng.choice((-16 * scale[2], 0, scale[2] - 1))
3105 border = (border_y, border_x)
3106 return scale, offset, border
3107
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003108 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003109 # Exclude illegal {mode, type} configurations. Pick legal output types
3110 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
3111 outputDTypeList = [DType.INT8]
3112 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
3113 outputDTypeList = [DType.INT16]
3114 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
3115 outputDTypeList = [DType.INT32]
3116 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
3117 outputDTypeList = [DType.INT48]
James Ward8b390432022-08-12 20:48:56 +01003118 elif dtype == DType.FP16:
3119 outputDTypeList = [DType.FP16]
James Ward24dbc422022-10-19 12:20:31 +01003120 elif dtype == DType.BF16:
3121 outputDTypeList = [DType.BF16]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003122 elif dtype == DType.FP32:
3123 outputDTypeList = [DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003124 elif error_name == ErrorIf.WrongInputType:
3125 # If an incorrect input type is used then we set a 'correct'
3126 # output type to avoid other errors
3127 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
3128 else:
3129 continue
3130
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003131 arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"
3132
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003133 for outputDType in outputDTypeList:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003134 perm = 0
3135 while perm < testGen.args.num_rand_permutations:
3136 # Random choice of type of params we are testing
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003137 if not testGen.args.level8k:
3138 _rnd_param_fn = testGen.rng.choice(
3139 (
3140 get_rand_params,
3141 get_upscale_downscale_params,
3142 get_aspect_ratio_resize_params,
3143 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003144 )
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003145 scale, offset, border = _rnd_param_fn()
3146 else:
3147 scale, offset, border = get_level_8k_params()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003148
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003149 # Expand params for bounds-checking
3150 (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
3151 (offset_y, offset_x) = offset
3152 (border_y, border_x) = border
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003153
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003154 # Make sure output dimensions OH and OW are integers
3155 partial_output_y = (
3156 (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
3157 )
3158 partial_output_x = (
3159 (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
3160 )
3161 if error_name == ErrorIf.ResizeOutputShapeNonInteger:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003162 # Look for non-integer test
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003163 if (
3164 partial_output_y % scale_y_d == 0
3165 and partial_output_x % scale_x_d == 0
3166 ):
3167 # Skip this test as it doesn't produce NonInteger output
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003168 if perm > 0:
3169 perm += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003170 continue
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003171 else:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003172 # Alter the scaling factors to make the output integer
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003173 while partial_output_y % scale_y_d != 0:
3174 scale_y_d -= 1
3175 while partial_output_x % scale_x_d != 0:
3176 scale_x_d -= 1
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003177 # Make sure we are still within max scaling
3178 if (
3179 scale_y_n / scale_y_d
3180 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
3181 scale_x_n / scale_x_d
3182 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
3183 # Skip the test as it is using too large a scaling factor
3184 if perm > 0:
3185 perm += 1
3186 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003187
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003188 output_y = partial_output_y // scale_y_d + 1
3189 output_x = partial_output_x // scale_x_d + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003190
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003191 if (
3192 output_y >= testGen.args.max_resize_output_dim
3193 or output_x >= testGen.args.max_resize_output_dim
3194 ) and error_name is None:
3195 # Skip positive test if output dim will be too high
3196 # Avoid high test latency and OOM issues
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003197 if not testGen.args.level8k or perm > 0:
3198 perm += 1
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003199 continue
3200
3201 if (
3202 output_y <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003203 or output_y >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003204 or output_x <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003205 or output_x >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003206 ):
3207 # Output dimensions out of scope
3208 if error_name is not None and perm > 0:
3209 # As long as we have one ERROR_IF test, don't worry
3210 # about creating all the other permutations
3211 perm += 1
3212 continue
3213
3214 if error_name == ErrorIf.ResizeOutputShapeMismatch and (
3215 (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003216 output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003217 and output_y - scale_y_d < 1
3218 )
3219 or (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003220 output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003221 and output_x - scale_x_d < 1
3222 )
3223 ):
3224 # Can't create a negative test with these params as it
3225 # will create invalid output size
3226 if perm > 0:
3227 perm += 1
3228 continue
3229
3230 scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
3231 offset = [offset_y, offset_x]
3232 border = [border_y, border_x]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003233
3234 # Common for all data types
3235 if error_name is not None:
3236 (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003237 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003238 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003239 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003240 outputDTypeNew,
3241 ) = TosaErrorIfArgGen.eiResizeErrorIf(
3242 testGen,
3243 error_name,
3244 mode,
3245 dtype,
3246 shapeList,
3247 outputDType,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003248 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003249 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003250 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003251 )
3252 else:
3253 outputDTypeNew = outputDType
3254
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003255 arg_to_append = (
3256 arg_str.format(
3257 "N" if mode == ResizeMode.NEAREST else "B",
3258 testGen.typeStr(outputDTypeNew),
3259 scale[0],
3260 scale[1],
3261 scale[2],
3262 scale[3],
3263 offset[0],
3264 offset[1],
3265 border[0],
3266 border[1],
3267 ),
3268 [
3269 mode,
3270 scale,
3271 offset,
3272 border,
3273 dtype,
3274 outputDTypeNew,
3275 ],
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003276 )
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003277 if arg_to_append in arg_list:
3278 # Skip already generated test params
3279 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003280
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003281 # Valid permutation
3282 perm += 1
3283 arg_list.append(arg_to_append)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003284 return arg_list
3285
3286 @staticmethod
3287 def agTable(testGen, opName, shapeList, dtype, error_name=None):
3288 arg_list = []
3289
3290 if dtype == DType.INT8:
3291 table = np.int32(
3292 testGen.rng.integers(low=-128, high=128, size=[256])
3293 ).tolist()
3294 else: # INT16
3295 table = np.int32(
3296 testGen.rng.integers(low=-32768, high=32768, size=[513])
3297 ).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003298 # Make sure all slopes are within REQUIRE min/max 16-bit int
3299 for idx in range(len(table) - 1):
3300 slope = table[idx + 1] - table[idx]
3301 # Alter the next table entry to force the slope to be ok
3302 if slope > 32767:
3303 table[idx + 1] -= slope - 32767
3304 if slope < -32768:
3305 table[idx + 1] -= slope + 32768
3306 slope = table[idx + 1] - table[idx]
3307 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003308 arg_list.append(
3309 (
3310 "",
3311 [table],
3312 )
3313 )
3314 return arg_list
3315
3316 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
3317 # CondIf generates the condition values here.
3318 # Convert to tensors in the build function, along with the
3319 # then and else blocks
3320 arg_list = []
3321
3322 for c in [False, True]:
3323 arg_list.append(("cond{}".format(int(c)), [c]))
3324
3325 return arg_list
3326
3327 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
3328 # While loop: 0 iterations, 1, more than 1
3329 arg_list = []
3330
3331 for iter in [0, 1, 4]:
3332 arg_list.append(("iter{}".format(iter), [iter]))
3333
3334 return arg_list