blob: 8501caace5903e1d299813917bc43c7860c2af7d [file] [log] [blame]
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
4import math
James Ward8b390432022-08-12 20:48:56 +01005import warnings
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01006
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
16# DTypeNames, DType, Op and ResizeMode are convenience variables to the
17# flatc-generated types that should be enums, but aren't
18
19
class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(testGen, dtype, error_name=None):
        """Return a zero point value suitable for the given dtype.

        INT8/UINT8 use the user-supplied zero point (clamped to the type's
        range) when one was given on the command line, otherwise a random
        in-range value.  For the zero-point ERROR_IF tests a deliberately
        non-zero value is returned.  All other dtypes use a zero point of 0.
        """
        if dtype == DType.INT8:
            if testGen.args.zeropoint is not None:
                # Clamp the user-supplied zero point into the INT8 range
                return min(127, max(-128, testGen.args.zeropoint))
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            if testGen.args.zeropoint is not None:
                # Clamp the user-supplied zero point into the UINT8 range
                return min(255, max(0, testGen.args.zeropoint))
            return testGen.randInt(0, 256)
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            # ERROR_IF tests require an invalid (non-zero) zero point
            zero_point = testGen.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Return [input_zp, output_zp] for a unary operator."""
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Return [input_zp, weight_zp] for a convolution style operator."""
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Return [a_zp, b_zp] for MATMUL (both non-zero for the ERROR_IF test)."""
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into a (multiplier, shift) pair.

        Derived from computeMultiplierAndShiftTosaScale32.  scale32 selects a
        multiplier with 31 fractional bits, otherwise 15 fractional bits are
        used.  The returned shift is guaranteed to be in the range [2, 62].
        """
        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            # frexp returns a negative mantissa for negative input; the
            # multiplier must be positive
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            # Rounding overflowed the mantissa - renormalize
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift
152
153
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Create one identical random shape per operand of a basic operator."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Create matching rank 4 (NHWC) shapes with a constricted batch size."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgGather(testGen, opName, rank, error_name=None):
        """Create shapes for GATHER: rank 3 values plus [N, W] indices."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_shape = testGen.makeShape(rank)
        values_shape = testGen.constrictBatchSize(values_shape)

        # Indices share the batch (N) of the values tensor
        N = values_shape[0]
        W = testGen.makeDimension()
        indices_shape = [N, W]

        shape_list = [values_shape, indices_shape]
        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Create shapes for SCATTER: values_in [N,K,C], indices [N,W], input [N,W,C]."""
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = testGen.randInt(W_min, W_max) if W_min < W_max else W_min

        input_shape = [N, W, C]

        shape_list = []
        shape_list.append(values_in_shape)
        shape_list.append([N, W])  # indices
        shape_list.append(input_shape)

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Create shapes where one randomly chosen input is broadcast (dim set to 1)."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        fuzz_idx = testGen.randInt(0, rank)

        for i in range(pl + const):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    # Create a dimension that cannot be broadcast against
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (OHWI), bias (OC)] shapes for CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Create [ifm (NDHWC), filter (ODHWI), bias (OC)] shapes for CONV3D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension()

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (OHWI), bias (OC)] shapes for TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (HWCM), bias (C*M)] shapes for DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        """Create two matching [N, H, W] shapes (H and W powers of two) for FFT2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 when the current size is 1 (1+1=2 is
            # still a power of two)
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        """Create a single [N, H, W] shape (H and W powers of two) for RFFT2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Create [input, filter (OC,IC), bias (OC)] shapes for FULLY_CONNECTED."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Create [A, B] rank 3 shapes for MATMUL with matching batch/inner dims."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Create shapes for CONCAT, adding a random number of extra inputs."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                # Create a rank mismatch by randomly removing or adding a dim
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Split the first CONCAT shape along axis across the later inputs."""
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList

    @staticmethod
    def tgShape(testGen, opName, rank, error_name=None):
        """Create 1-D shapes of length rank for shape operators."""
        pl, const = opName["operands"]
        shape = [rank]

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list
646
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100647
class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each tensor in each test."""

    def __init__(self):
        pass

    class TVGInfo:
        """Enhanced tensor values information including data gen dict."""

        def __init__(self, tensorList, dataGenDict):
            # Serializer tensors created for the test
            self.tensorList = tensorList
            # Data generation meta-data dictionary (or None when the data
            # was created directly rather than via the data gen library)
            self.dataGenDict = dataGenDict
660
    @staticmethod
    def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """Default generation of test tensors.

        Builds placeholder tensors for the first pCount shapes/dtypes and
        const tensors for the remainder, returning the combined list.
        """
        pCount, cCount = op["operands"]

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        return tens
672
    # Default high value for random numbers
    # Each entry is the largest finite value of the format:
    # FP32: 2^128 - 2^104, FP16: 2^16 - 2^5 (65504), BF16: 2^128 - 2^120
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
    }

    # Default lowest normal values for random numbers
    # Each entry is the smallest positive normal value of the format
    TVG_FLOAT_LOW_VALUE = {
        DType.FP32: np.exp2(-126),
        DType.FP16: np.exp2(-14),
        DType.BF16: np.exp2(-126),
    }
686
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100687 @staticmethod
Jeremy Johnson30476252023-11-20 16:15:30 +0000688 def _get_data_range(testGen, dtype, highValueLookup, lowValueLookup=None):
689 # Return a tuple of (low,high) data range values for the given data
690 # type using a combination of per operator table limits, data limits
691 # and user supplied ranges for FP numbers
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000692 if dtype in highValueLookup:
Jeremy Johnson30476252023-11-20 16:15:30 +0000693 type_range = testGen.getDTypeRange(dtype, high_inclusive=True)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000694 high_val = highValueLookup[dtype]
Jeremy Johnson30476252023-11-20 16:15:30 +0000695 if lowValueLookup is not None and dtype in lowValueLookup:
696 low_val = lowValueLookup[dtype]
697 else:
698 low_val = -high_val
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000699 # Set the values to something that won't produce infinity whilst
Jeremy Johnson30476252023-11-20 16:15:30 +0000700 # respecting the default ranges if more/less than the low/high
701 # values
702 data_range = (
703 max(low_val, type_range[0]),
704 min(high_val, type_range[1]),
705 )
706 if data_range[0] > data_range[1]:
707 # Invalid data range from low to high created due to user
708 # constraints revert to using internal ranges as they are
709 # known to work
710 msg = f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
711 warnings.warn(msg)
712 data_range = (low_val, high_val)
713 return data_range
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000714 return None
715
716 @staticmethod
Jeremy Johnson1271c442023-09-05 11:39:26 +0100717 def tvgLazyGenDefault(
718 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
719 ):
720 # Variable inputs versus constants
721 pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
Jeremy Johnson3eafe662024-01-10 13:13:35 +0000722 if "p_count" in argsDict:
723 # Override for operators like CONCAT
724 pCount = argsDict["p_count"]
725 cCount = argsDict["c_count"]
726 assert pCount + cCount == len(
727 shapeList
728 ), "Placeholders & Constant tensors must match shapes list"
729
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000730 tens_ser_list = []
Jeremy Johnson1271c442023-09-05 11:39:26 +0100731
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100732 if (
733 error_name is not None
734 or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100735 or "data_gen" not in testGen.TOSA_OP_LIST[opName]
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100736 ):
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000737 # Fall back to internal data gen when dealing with unsupported types or ops
738 data_range = argsDict["data_range"] if "data_range" in argsDict else None
739 for idx, info in enumerate(zip(shapeList, dtypeList)):
Jeremy Johnson30476252023-11-20 16:15:30 +0000740 roundMode = False
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000741 shape, dtype = info
Jeremy Johnson30476252023-11-20 16:15:30 +0000742 if "data_range_list" in argsDict:
743 data_range = argsDict["data_range_list"][idx]["range"]
744 roundMode = (
745 "round" in argsDict["data_range_list"][idx]
746 and argsDict["data_range_list"][idx]["round"] is True
747 )
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000748 if data_range is not None and dtype not in (
749 DType.FP16,
750 DType.FP32,
751 DType.BF16,
752 ):
753 # Change from inclusive to exclusive range
754 data_range = (data_range[0], data_range[1] + 1)
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000755 # Ignore lazy data gen option and create data array using any range limits
Won Jeon64e4bfe2024-01-18 06:31:55 +0000756
757 if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
758 arr = np.int64(argsDict["fixed_data"][idx])
759 else:
760 arr = testGen.getRandTensor(shape, dtype, data_range)
Jeremy Johnson30476252023-11-20 16:15:30 +0000761 if roundMode:
762 arr = np.round(arr)
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000763 if idx < pCount:
764 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
765 else:
766 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100767
Jeremy Johnson1271c442023-09-05 11:39:26 +0100768 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
769
770 # Create data generator meta-data
771 dg_type = argsDict["dg_type"]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100772 tens_data = {
773 "version": "0.1",
774 "tensors": {},
775 }
776 dg_tens_meta = tens_data["tensors"]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100777 for idx, shape in enumerate(shapeList):
778
779 tens_meta = {}
Won Jeon64e4bfe2024-01-18 06:31:55 +0000780 if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
781 tens_meta["generator"] = gtu.DataGenType(
782 gtu.DataGenType.FIXED_DATA
783 ).name
784 else:
785 tens_meta["generator"] = gtu.DataGenType(dg_type).name
786
Jeremy Johnson1271c442023-09-05 11:39:26 +0100787 tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
788 tens_meta["shape"] = [int(i) for i in shape]
789 tens_meta["input_pos"] = idx
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100790 tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()
Jeremy Johnson1271c442023-09-05 11:39:26 +0100791
792 if idx < pCount:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100793 tens_meta["input_type"] = "VARIABLE"
Jeremy Johnson1271c442023-09-05 11:39:26 +0100794 else:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100795 tens_meta["input_type"] = "CONSTANT"
Jeremy Johnson1271c442023-09-05 11:39:26 +0100796
797 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
798 info = {}
Won Jeon64e4bfe2024-01-18 06:31:55 +0000799 if (
800 tens_meta["generator"]
801 == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name
802 ):
803 info["data"] = [int(i) for i in argsDict["fixed_data"][idx]]
804 tens_meta["fixed_data_info"] = info
805 else:
806 # TODO - generate seed for this generator based on test
807 info["rng_seed"] = 42
Jeremy Johnson30476252023-11-20 16:15:30 +0000808
Won Jeon64e4bfe2024-01-18 06:31:55 +0000809 data_range = None
810 if "data_range_list" in argsDict:
811 data_range = argsDict["data_range_list"][idx]["range"]
812 if "round" in argsDict["data_range_list"][idx]:
813 info["round"] = argsDict["data_range_list"][idx]["round"]
814 elif "data_range" in argsDict:
815 data_range = argsDict["data_range"]
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000816
Won Jeon64e4bfe2024-01-18 06:31:55 +0000817 if data_range is None:
818 data_range = testGen.getDTypeRange(
819 dtypeList[idx], high_inclusive=True
820 )
821 info["range"] = [str(v) for v in data_range]
822 tens_meta["pseudo_random_info"] = info
Jeremy Johnson1271c442023-09-05 11:39:26 +0100823 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
824 info = {}
825 info["s"] = argsDict["s"]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100826 info["ks"] = int(argsDict["ks"])
827 if "acc_type" in argsDict:
828 # Convert type number into JSON name
829 info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
830 "json"
831 ]
832 if "kernel" in argsDict:
833 info["kernel"] = [int(k) for k in argsDict["kernel"]]
834 if "axis" in argsDict:
835 info["axis"] = int(argsDict["axis"])
Jeremy Johnson1271c442023-09-05 11:39:26 +0100836 tens_meta["dot_product_info"] = info
837 else:
838 # TODO - other data gen type
839 assert False, "TODO: support other data gen types"
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100840
841 # Using the finished generate config meta data - generate the data if
842 # needed and assign a tensor name from the serializer
843
844 # Need to generate data when not lazy or for the bias tensor as we need
845 # to work out if the bias data is non-zero for compliance
846 if not testGen.args.lazy_data_gen or (
847 idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
848 ):
849 # Give this tensor a temporary name until we get one from the serializer
850 temp_name = f"placeholder_{idx}"
851 dg_tens_meta[temp_name] = tens_meta
852 # Create data now using the temporary name to access meta details
853 data = testGen.dgl.get_tensor_data(temp_name, tens_data)
Won Jeon64e4bfe2024-01-18 06:31:55 +0000854 if tens_meta["data_type"] == "SHAPE":
855 # Tensor type SHAPE and Numpy file type must be the same
856 data = np.int64(data)
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100857 # Remove the item as we will give it the correct name later
858 del dg_tens_meta[temp_name]
859
860 if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
861 # The KS value used by compliance verification is altered when the
862 # bias data is non-zero
863 if max(abs(data)) > 0.0:
864 argsDict["ksb"] = argsDict["ks"] + 1
865
866 if testGen.args.lazy_data_gen:
867 data = None
868
869 if tens_meta["input_type"] == "VARIABLE":
870 tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
871 else:
872 tens = testGen.ser.addConst(shape, dtypeList[idx], data)
873
874 tens_ser_list.append(tens)
875 # Add the meta data to the list using the serializer tensor name
Jeremy Johnson1271c442023-09-05 11:39:26 +0100876 dg_tens_meta[tens.name] = tens_meta
877
Jeremy Johnson1271c442023-09-05 11:39:26 +0100878 return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
879
880 @staticmethod
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000881 def tvgNegate(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100882 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000883 # Integer test
884 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100885 pCount, cCount = op["operands"]
886 assert (
887 pCount == 1 and cCount == 0
888 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100889 # Must create tensors with values within accumulator (int32) negatable
890 # range
891 max_val = (1 << 31) - 1
892 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100893 arr = np.int32(
894 testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
895 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000896 tens_ser_list = []
897 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100898 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
899 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000900 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100901 else:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000902 # ERROR_IF or floating point test
903 return TosaTensorValuesGen.tvgLazyGenDefault(
904 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100905 )
906
Jeremy Johnson30476252023-11-20 16:15:30 +0000907 # Set the ADD/SUB data range to half the largest value to avoid infinities
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000908 TVG_FLOAT_HIGH_VALUE_ADDSUB = {
909 DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
910 DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
911 DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
912 }
913
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100914 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000915 def tvgAddSub(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Won Jeon74342e52024-01-09 00:34:40 +0000916 if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000917 # Make sure the integer operation does not cause value saturation - where
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100918 # the number wraps due to limited number of bits to store the answer
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000919 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100920 pCount, cCount = op["operands"]
921 assert (
922 pCount == 2 and cCount == 0
923 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000924 tens_ser_list = []
Won Jeon74342e52024-01-09 00:34:40 +0000925 add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
926 data_range = testGen.args.tensor_shape_range
927 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0], data_range)
928 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1], data_range)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100929 if add:
930 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
931 else:
932 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
933
934 # Work out the saturation limits
935 max_i32 = (1 << 31) - 1
936 min_i32 = -(1 << 31)
937 max_arr = np.full(shapeList[1], max_i32)
938 min_arr = np.full(shapeList[1], min_i32)
939
940 # Find how much values exceed the maximum/minimums
941 sat_max_arr = np.maximum(res_arr - max_arr, 0)
942 sat_min_arr = np.minimum(res_arr - min_arr, 0)
943
944 if not add:
945 # Swap saturation values and negate values as we need to perform opposite operations
946 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
947
948 # Create new array of unsaturated values by clipping values as needed
949 b_unsat_arr = b_arr
950 if (sat_max_arr != 0).any():
951 # Clip values that cause saturation
952 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
953 # Reduce axes in unsaturated tensor to match original tensor
954 for axis, dim in enumerate(b_arr.shape):
955 if dim != b_unsat_arr.shape[axis]:
956 assert (
957 dim == 1
958 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
959 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
960
961 if (sat_min_arr != 0).any():
962 # Clip values that cause saturation
963 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
964 # Reduce axes in unsaturated tensor to match original tensor
965 for axis, dim in enumerate(b_arr.shape):
966 if dim != b_unsat_arr.shape[axis]:
967 assert (
968 dim == 1
969 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
970 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
971
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000972 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100973 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
974 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000975 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100976 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
977 )
978
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000979 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100980 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000981 # ERROR_IF or floating point test
982 data_range = TosaTensorValuesGen._get_data_range(
983 testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
984 )
985 if data_range:
986 argsDict["data_range"] = data_range
987
988 return TosaTensorValuesGen.tvgLazyGenDefault(
989 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100990 )
991
992 @staticmethod
993 def tvgCondIfWhileLoop(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000994 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100995 ):
996 if dtypeList[0] in (
997 DType.INT32,
998 DType.INT16,
999 DType.INT8,
1000 ):
1001 # Limit input tensors with cond_if_binary or while_loop to stop
1002 # saturation of add/sub ops with int32 and keep all logical shift
1003 # values between 0 to 31 for int16 or int8
1004 pCount, cCount = op["operands"]
1005 pRemain = pCount
1006 placeholders = []
1007 for idx, shape in enumerate(shapeList[:]):
1008 if dtypeList[0] == DType.INT32:
1009 arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
1010 else:
1011 arr = np.int32(
1012 testGen.rng.integers(low=0, high=32, size=shapeList[idx])
1013 )
1014 if pRemain > 0:
1015 placeholders.append(
1016 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
1017 )
1018 pRemain -= 1
1019 else:
1020 placeholders.append(
1021 testGen.ser.addConst(shape, dtypeList[idx], arr)
1022 )
1023
1024 return placeholders
1025 else:
1026 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001027 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001028 )
1029
1030 @staticmethod
1031 def tvgArithmeticRightShift(
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001032 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001033 ):
1034 pCount, cCount = op["operands"]
1035 # Force value of operand[1] to be within [0, num_bits]
1036 assert (
1037 pCount == 2 and cCount == 0
1038 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
1039
1040 placeholders = []
1041 for idx, shape in enumerate(shapeList[:]):
1042 if idx == 1:
1043 if dtypeList[idx] == DType.INT8:
1044 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
1045 elif dtypeList[idx] == DType.INT16:
1046 arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
1047 elif dtypeList[idx] == DType.INT32:
1048 arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
1049 elif error_name == ErrorIf.WrongInputType:
1050 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
1051 else:
1052 raise Exception("OpArithmeticRightShift: invalid input dtype")
1053 else:
1054 arr = testGen.getRandTensor(shape, dtypeList[idx])
1055 placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
1056
1057 return placeholders
1058
1059 @staticmethod
Won Jeon64e4bfe2024-01-18 06:31:55 +00001060 def tvgReshape(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
1061 dtypeList[1] = DType.SHAPE
1062 shapeList[1] = [len(argsDict["new_shape"])]
1063 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1064 argsDict["fixed_data"] = [None, argsDict["new_shape"]]
1065
1066 return TosaTensorValuesGen.tvgLazyGenDefault(
1067 testGen, op, dtypeList, shapeList, argsDict, error_name
1068 )
1069
1070 @staticmethod
1071 def tvgTile(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
1072 dtypeList[1] = DType.SHAPE
1073 shapeList[1] = [len(argsDict["multiples"])]
1074 argsDict["fixed_data"] = [None, argsDict["multiples"]]
1075
1076 return TosaTensorValuesGen.tvgLazyGenDefault(
1077 testGen, op, dtypeList, shapeList, argsDict, error_name
1078 )
1079
1080 @staticmethod
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00001081 def tvgSelect(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001082 # Set datatype of condition tensor to boolean
1083 dtypeList[0] = DType.BOOL
1084
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00001085 return TosaTensorValuesGen.tvgLazyGenDefault(
1086 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001087 )
1088
1089 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001090 def tvgIntDiv(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001091 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001092 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001093 pCount, cCount = op["operands"]
1094 assert (
1095 pCount == 2 and cCount == 0
1096 ), "Op.INTDIV must have 2 placeholders, 0 consts"
1097
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001098 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001099
1100 # Two invalid cases for Op.INTDIV:
1101 # 1. divisor == 0
1102 # 2. dividend == -(1<<31) and divisor == -1
1103 while True:
1104 dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1105 divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
1106
1107 if (divisor_arr == 0).any():
1108 continue
1109
1110 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
1111 continue
1112
1113 break
1114
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001115 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001116 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1117 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001118 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001119 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1120 )
1121
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001122 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001123 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001124 return TosaTensorValuesGen.tvgLazyGenDefault(
1125 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001126 )
1127
Jeremy Johnson30476252023-11-20 16:15:30 +00001128 # Set the MUL data range to the square root of the largest value
1129 # to avoid infinities
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001130 TVG_FLOAT_HIGH_VALUE_MUL = {
1131 DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1132 DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1133 DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1134 }
1135
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001136 @staticmethod
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001137 def tvgMul(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1138 if error_name is not None or dtypeList[0] in (
1139 DType.FP16,
1140 DType.BF16,
1141 DType.FP32,
1142 ):
1143 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001144 data_range = TosaTensorValuesGen._get_data_range(
1145 testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
1146 )
1147 if data_range:
1148 argsDict["data_range"] = data_range
1149
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001150 return TosaTensorValuesGen.tvgLazyGenDefault(
1151 testGen, opName, dtypeList, shapeList, argsDict, error_name
1152 )
1153 else:
1154 # Integer test
1155 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001156 pCount, cCount = op["operands"]
1157 assert (
1158 pCount == 2 and cCount == 0
1159 ), "Op.MUL must have 2 placeholders, 0 consts"
1160
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001161 tens_ser_list = []
1162
1163 # Make sure multiply result in int32 range
Won Jeon74342e52024-01-09 00:34:40 +00001164 if dtypeList[0] == DType.SHAPE:
1165 shift = 0
1166 else:
1167 shift = argsDict["shift"]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001168 if dtypeList[0] == DType.INT8:
1169 num_bits = 8
1170 elif dtypeList[0] == DType.INT16:
1171 num_bits = 16
Won Jeon74342e52024-01-09 00:34:40 +00001172 elif dtypeList[0] in (DType.INT32, DType.SHAPE):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001173 num_bits = 32
1174 elif error_name == ErrorIf.WrongInputType:
1175 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001176 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001177 raise Exception("OpMul: invalid input dtype")
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001178
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001179 for idx, shape in enumerate(shapeList[:]):
Won Jeon74342e52024-01-09 00:34:40 +00001180 if dtypeList[idx] == DType.SHAPE:
1181 low = testGen.args.tensor_shape_range[0]
1182 high = testGen.args.tensor_shape_range[1]
1183 else:
1184 low = -(2 ** (num_bits - 1))
1185 high = (2 ** (num_bits - 1)) - 1
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001186
1187 a_arr = np.int32(
1188 testGen.rng.integers(low=low, high=high, size=shapeList[0])
1189 )
1190 b_arr = np.int32(
1191 testGen.rng.integers(low=low, high=high, size=shapeList[1])
1192 )
1193
1194 i = 0
1195 while True:
1196
1197 a_arr_64 = a_arr.astype(np.int64)
1198 b_arr_64 = b_arr.astype(np.int64)
1199
1200 if shift > 0:
1201 rounding = 1 << (shift - 1)
1202 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001203 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001204 result_arr = a_arr_64 * b_arr_64
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001205
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001206 if (result_arr > -(2**31)).all() and (
1207 result_arr <= ((2**31) - 1)
1208 ).all():
1209 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001210
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001211 i = i + 1
1212 a_arr = a_arr // 2
1213 b_arr = b_arr // 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001214
Won Jeon74342e52024-01-09 00:34:40 +00001215 if dtypeList[0] == DType.SHAPE:
1216 tens_ser_list.append(
1217 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
1218 )
1219 tens_ser_list.append(
1220 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
1221 )
1222 else:
1223 tens_ser_list.append(
1224 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1225 )
1226 tens_ser_list.append(
1227 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1228 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001229
1230 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001231
1232 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001233 def tvgConcat(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001234 count = len(shapeList) - testGen.args.num_const_inputs_concat
1235 if count < 1:
1236 count = 1
1237 if testGen.args.num_const_inputs_concat == 0:
1238 count = len(shapeList)
1239
Won Jeon74342e52024-01-09 00:34:40 +00001240 op = testGen.TOSA_OP_LIST[opName]
1241 if op["op"] == Op.CONCAT_SHAPE:
1242 # Set the axis to 0
1243 shapeList = TosaTensorGen.tgConcatConstInput(
1244 testGen, shapeList, 0, error_name
1245 )
1246 else:
1247 shapeList = TosaTensorGen.tgConcatConstInput(
1248 testGen, shapeList, argsDict["axis"], error_name
1249 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001250
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001251 # Override default pCount/cCount for operator
1252 argsDict["p_count"] = count
1253 argsDict["c_count"] = len(shapeList) - count
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001254
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001255 return TosaTensorValuesGen.tvgLazyGenDefault(
1256 testGen, opName, dtypeList, shapeList, argsDict, error_name
1257 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001258
1259 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001260 def tvgLogicalShift(
1261 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1262 ):
1263 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001264 pCount, cCount = op["operands"]
1265 assert (
1266 pCount == 2 and cCount == 0
1267 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
1268 values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1269 shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001270 tens_ser_list = []
1271 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001272 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1273 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001274 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001275 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1276 )
1277
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001278 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001279
1280 @staticmethod
Jeremy Johnsona0150012023-11-15 15:52:06 +00001281 def tvgEqual(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1282 if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
1283 # Integer
1284 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001285 pCount, cCount = op["operands"]
1286 assert (
1287 pCount == 2 and cCount == 0
1288 ), "Op.EQUAL must have 2 placeholders, 0 consts"
Jeremy Johnsona0150012023-11-15 15:52:06 +00001289
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001290 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1291 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
Jeremy Johnsona0150012023-11-15 15:52:06 +00001292
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001293 # Using random numbers means that it will be very unlikely that
1294 # there are any matching (equal) values, therefore force that
1295 # there are twice the number of matching values as the tensor rank
1296 for num in range(0, len(shapeList[0]) * 2):
1297 a_index = []
1298 b_index = []
1299 # Choose an index in each axis for the whole shape
1300 for axis in range(0, len(shapeList[0])):
1301 # Index can be up to the largest dimension in both shapes
1302 index = np.int32(
1303 testGen.rng.integers(
1304 0, max(shapeList[0][axis], shapeList[1][axis])
1305 )
1306 )
1307 # Reduce the index down to a shape's dim for broadcasting
1308 a_index.append(min(shapeList[0][axis] - 1, index))
1309 b_index.append(min(shapeList[1][axis] - 1, index))
1310
1311 a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
1312
Jeremy Johnsona0150012023-11-15 15:52:06 +00001313 tens_ser_list = []
1314 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001315 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1316 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001317 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001318 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1319 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001320 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001321 else:
Jeremy Johnsona0150012023-11-15 15:52:06 +00001322 # ERROR_IF or floating point test
1323 return TosaTensorValuesGen.tvgLazyGenDefault(
1324 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001325 )
1326
1327 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001328 def tvgReduceSum(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001329 dtype = dtypeList[0]
1330 if dtype == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001331 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001332 pCount, cCount = op["operands"]
1333 assert (
1334 pCount == 1 and cCount == 0
1335 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
1336 # Limit values so that the sum cannot exceed the range of an int32 during
1337 # summation of any axis
1338 range_val = int((1 << 31) / max(shapeList[0]))
1339 values_arr = np.int32(
1340 testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
1341 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001342 tens_ser_list = []
1343 tens_ser_list.append(
Jeremy Johnson30476252023-11-20 16:15:30 +00001344 testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001345 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001346 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001347 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001348 # ERROR_IF or dot product floating point test
Jeremy Johnson30476252023-11-20 16:15:30 +00001349 if (
1350 error_name is None
1351 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
1352 ):
1353 # Limit ranges for (non error & non compliance) tests by using
1354 # values that can be summed on any axis to not hit infinity
1355 highval_lookup = {
1356 dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
1357 / max(shapeList[0])
1358 }
1359 data_range = TosaTensorValuesGen._get_data_range(
1360 testGen, dtype, highval_lookup
1361 )
1362 assert data_range is not None
1363 argsDict["data_range"] = data_range
1364
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001365 return TosaTensorValuesGen.tvgLazyGenDefault(
1366 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001367 )
1368
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001369 @staticmethod
1370 def tvgReduceProduct(
1371 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1372 ):
1373 dtype = dtypeList[0]
1374 if error_name is None:
1375 # Limit ranges for (non error) tests by using
1376 # values that can be multiplied on any axis to not hit infinity
1377 highval_lookup = {
1378 dtype: math.pow(
1379 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
1380 1 / max(shapeList[0]),
1381 )
1382 }
1383 data_range = TosaTensorValuesGen._get_data_range(
1384 testGen, dtype, highval_lookup
1385 )
1386 assert data_range is not None
1387 argsDict["data_range"] = data_range
1388
1389 return TosaTensorValuesGen.tvgLazyGenDefault(
1390 testGen, opName, dtypeList, shapeList, argsDict, error_name
1391 )
1392
Jeremy Johnson30476252023-11-20 16:15:30 +00001393 # Set the POW exponent high data range
1394 TVG_FLOAT_HIGH_VALUE_POW_EXP = {
1395 DType.FP32: 10.0,
1396 DType.FP16: 10.0,
1397 DType.BF16: 10.0,
1398 }
1399 # POW highest base value (within a safe margin of error) that can be raised
1400 # to +ve exponent that doesn't become Infinity
1401 TVG_FLOAT_HIGH_VALUE_POW_BASE = {
1402 DType.FP32: math.floor(
1403 math.pow(
1404 TVG_FLOAT_HIGH_VALUE[DType.FP32],
1405 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1406 )
1407 ),
1408 DType.FP16: math.floor(
1409 math.pow(
1410 TVG_FLOAT_HIGH_VALUE[DType.FP16],
1411 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1412 )
1413 ),
1414 DType.BF16: math.floor(
1415 math.pow(
1416 TVG_FLOAT_HIGH_VALUE[DType.BF16],
1417 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1418 )
1419 ),
1420 }
1421 # POW lowest base value (within a safe margin of error) that can be raised
1422 # to -ve exponent that doesn't become Infinity
1423 TVG_FLOAT_LOW_VALUE_POW_BASE = {
1424 DType.FP32: math.ceil(
1425 math.pow(
1426 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
1427 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1428 )
1429 * 1000
1430 )
1431 / 1000,
1432 DType.FP16: math.ceil(
1433 math.pow(
1434 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
1435 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1436 )
1437 * 1000
1438 )
1439 / 1000,
1440 DType.BF16: math.ceil(
1441 math.pow(
1442 1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
1443 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1444 )
1445 * 1000
1446 )
1447 / 1000,
1448 }
1449
    @staticmethod
    def tvgPow(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create input tensor values for POW tests.

        Three test sets (argsDict["s"]) select different base/exponent ranges
        that avoid Infinity/NaN results:
          0 - positive base, fractional exponent
          1 - positive base, integer (rounded) exponent
          2 - negative base, integer (rounded) exponent
        The per-operand ranges are passed to the lazy generator via
        argsDict["data_range_list"].
        """
        if error_name is not None:
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
        dtype = dtypeList[0]
        # Different ranges for POW
        test_set = argsDict["s"]
        if test_set == 0:
            # Positive base with fractional exponent
            base_range = TosaTensorValuesGen._get_data_range(
                testGen,
                dtype,
                TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
                TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
            )
            exp_range = TosaTensorValuesGen._get_data_range(
                testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
            )
            exp_round = False
        else:
            # Integer exponent
            exp_range = TosaTensorValuesGen._get_data_range(
                testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
            )
            exp_round = True
            if test_set == 1:
                # Positive base
                base_range = TosaTensorValuesGen._get_data_range(
                    testGen,
                    dtype,
                    TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
                    TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
                )
            else:
                assert test_set == 2
                # Negative base
                # Supply new look up tables with negative values
                # (high/low swap roles when negated)
                base_range = TosaTensorValuesGen._get_data_range(
                    testGen,
                    dtype,
                    {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
                    {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
                )

        # Per-operand ranges: entry 0 is the base, entry 1 is the exponent
        data_range_list = (
            {
                "range": base_range,
            },
            {
                "range": exp_range,
                "round": exp_round,
            },
        )
        argsDict["data_range_list"] = data_range_list
        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, opName, dtypeList, shapeList, argsDict, error_name
        )
1509
1510 @staticmethod
1511 def tvgLogRsqrt(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1512 # LOG & RSQRT data range from lowest expressible positive number to
1513 # largest to avoid NaNs
1514 data_range = TosaTensorValuesGen._get_data_range(
1515 testGen,
1516 dtypeList[0],
1517 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1518 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1519 )
1520 if data_range:
1521 argsDict["data_range"] = data_range
1522
1523 return TosaTensorValuesGen.tvgLazyGenDefault(
1524 testGen, opName, dtypeList, shapeList, argsDict, error_name
1525 )
1526
    # Set the EXP data range to the log of the largest to smallest values
    # to avoid infinities or making the result zero
    # (exp(log(x)) == x, so these bounds map back onto the dtype's range)
    TVG_FLOAT_HIGH_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
    TVG_FLOAT_LOW_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
    }
1539
1540 @staticmethod
1541 def tvgExp(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1542 data_range = TosaTensorValuesGen._get_data_range(
1543 testGen,
1544 dtypeList[0],
1545 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1546 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1547 )
1548 if data_range:
1549 argsDict["data_range"] = data_range
1550
1551 return TosaTensorValuesGen.tvgLazyGenDefault(
1552 testGen, opName, dtypeList, shapeList, argsDict, error_name
1553 )
1554
1555 @staticmethod
1556 def tvgFullyConnected(
1557 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1558 ):
1559 dtype = dtypeList[0]
1560 if (
1561 error_name is None
1562 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
Jeremy Johnson718f3472023-11-30 14:18:19 +00001563 and dtype in (DType.BF16,)
Jeremy Johnson30476252023-11-20 16:15:30 +00001564 ):
Jeremy Johnson718f3472023-11-30 14:18:19 +00001565 # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
Jeremy Johnson30476252023-11-20 16:15:30 +00001566 # Limit ranges for (non error & non compliance) FP tests by using
1567 # values that can be multiplied on any axis to not hit infinity/NaN
1568 IC = shapeList[0][1]
1569 highval_lookup = {
1570 dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
1571 }
1572 data_range = TosaTensorValuesGen._get_data_range(
1573 testGen, dtype, highval_lookup
1574 )
1575 assert data_range is not None
1576 argsDict["data_range"] = data_range
1577
1578 return TosaTensorValuesGen.tvgLazyGenDefault(
1579 testGen, opName, dtypeList, shapeList, argsDict, error_name
1580 )
1581
Jeremy Johnson708da822023-11-15 16:25:45 +00001582 @staticmethod
1583 def tvgCast(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1584 in_dtype = dtypeList[0]
1585 out_dtype = argsDict["out_type"]
1586 # Create look up to limit input tensor to output type maximums to avoid
1587 # FP infinities and saturation of integers
1588 out_range = testGen.getDTypeRange(out_dtype, high_inclusive=True)
1589 highval_lookup = {in_dtype: out_range[1]}
1590 data_range = TosaTensorValuesGen._get_data_range(
1591 testGen,
1592 in_dtype,
1593 highval_lookup,
1594 )
1595
1596 assert data_range is not None
1597 argsDict["data_range"] = data_range
1598
1599 return TosaTensorValuesGen.tvgLazyGenDefault(
1600 testGen, opName, dtypeList, shapeList, argsDict, error_name
1601 )
1602
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001603 @staticmethod
1604 def tvgGather(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1605 K = shapeList[0][1]
1606
1607 # Fix the type of the indices tensor
1608 dtypeList[1] = DType.INT32
1609
1610 dtype = dtypeList[0]
1611 if not gtu.dtypeIsSupportedByCompliance(dtype):
1612 # Test unsupported by data generator
1613 op = testGen.TOSA_OP_LIST[opName]
1614 pCount, cCount = op["operands"]
1615 assert (
1616 pCount == 2 and cCount == 0
1617 ), "Op.GATHER must have 2 placeholders, 0 consts"
1618
1619 tens_ser_list = []
1620 for idx, shape in enumerate(shapeList):
1621 dtype = dtypeList[idx]
1622 if idx != 1:
1623 arr = testGen.getRandTensor(shape, dtype)
1624 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1625 else:
1626 # Limit data range of indices tensor upto K (exclusive)
1627 arr = testGen.getRandTensor(shape, dtype, (0, K))
1628 # To match old functionality - create indices as CONST
1629 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1630
1631 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1632
1633 else:
1634 # ERROR_IF or floating point test
1635 # Use inclusive values upto index K for indices tensor
1636 data_range_list = (
1637 {"range": None},
1638 {"range": (0, K - 1)},
1639 )
1640 argsDict["data_range_list"] = data_range_list
1641
1642 return TosaTensorValuesGen.tvgLazyGenDefault(
1643 testGen, opName, dtypeList, shapeList, argsDict, error_name
1644 )
1645
1646 @staticmethod
1647 def tvgScatter(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1648 K = shapeList[0][1]
1649 W = shapeList[2][1]
1650
1651 # Work out an indices tensor here with data that doesn't exceed the
1652 # dimension K of the values_in tensor and does NOT repeat the same K
1653 # location as needed by the spec:
1654 # "It is not permitted to repeat the same output index within a single
1655 # SCATTER operation and so each output index occurs at most once."
1656 assert K >= W, "Op.SCATTER W must be smaller or equal to K"
1657
1658 # Fix the type of the indices tensor
1659 dtypeList[1] = DType.INT32
1660
1661 dtype = dtypeList[0]
1662 if not gtu.dtypeIsSupportedByCompliance(dtype):
1663 # Test unsupported by data generator
1664 op = testGen.TOSA_OP_LIST[opName]
1665 pCount, cCount = op["operands"]
1666 assert (
1667 pCount == 3 and cCount == 0
1668 ), "Op.SCATTER must have 3 placeholders, 0 consts"
1669
1670 tens_ser_list = []
1671 for idx, shape in enumerate(shapeList):
1672 dtype = dtypeList[idx]
1673 if idx != 1:
1674 arr = testGen.getRandTensor(shape, dtype)
1675 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1676 else:
1677 # Create the indices array
1678 assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
1679 arr = []
1680 for n in range(shape[0]):
1681 # Get a shuffled list of output indices (0 to K-1) and
1682 # limit length to W
1683 arr.append(testGen.rng.permutation(K)[:W])
1684 indices_arr = np.array(arr, dtype=np.int32) # (N, W)
1685 # To match old functionality - create indices as CONST
1686 tens_ser_list.append(
1687 testGen.ser.addConst(shape, dtype, indices_arr)
1688 )
1689
1690 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1691
1692 else:
1693 # ERROR_IF or floating point test
1694 # Use inclusive values upto index K for indices tensor
1695 data_range_list = (
1696 {"range": None},
1697 {"range": (0, K - 1)},
1698 {"range": None},
1699 )
1700 argsDict["data_range_list"] = data_range_list
1701
1702 return TosaTensorValuesGen.tvgLazyGenDefault(
1703 testGen, opName, dtypeList, shapeList, argsDict, error_name
1704 )
1705
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001706
1707class TosaArgGen:
1708 """Argument generators create exhaustive or random lists of attributes for
1709 operators that take attributes or other parameters.
1710
1711 The return value is a list of (descriptive_name, [arglist]) tuples where
1712 the descriptive_name is appended to the test name and the arglist is expanded
1713 as arguments to the operator build function.
1714 """
1715
    def __init__(self):
        # No instance state - all argument generators are static methods.
        pass
1718
    @staticmethod
    def _add_data_generators(testGen, opName, dtype, arg_list, error_name):
        """Add extra tests for each type of data generator for this op.

        Expands each (arg_str, args_dict) entry in arg_list per applicable
        data generator type, tagging every args_dict with "dg_type".
        Returns the expanded list of (arg_str, args_dict) tuples.

        NOTE(review): each args_dict is mutated in place and, when no extra
        test sets are needed, the SAME dict object is appended for every
        dg_type iteration - callers appear to rely on this; confirm before
        restructuring.
        """
        if (
            error_name is None
            and "data_gen" in testGen.TOSA_OP_LIST[opName]
            and gtu.dtypeIsSupportedByCompliance(dtype)
        ):
            # Pick the generator list matching floating point or integer dtype
            if dtype in [DType.FP16, DType.FP32, DType.BF16]:
                dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
            else:
                dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
        else:
            # Error test or No data generator types listed - assume random
            dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)

        # Expand arg list with other data generator types
        new_arg_list = []
        for dg_type in dataGenTypesList:
            for arg_str, args_dict in arg_list:
                args_dict["dg_type"] = dg_type
                if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                    # Optional "num_test_sets" requests multiple random test sets
                    if error_name is None:
                        num_test_sets = (
                            args_dict["num_test_sets"]
                            if "num_test_sets" in args_dict
                            else 0
                        )
                    else:
                        num_test_sets = 0

                elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                    # Extra tests for each dot product test set
                    dot_products = args_dict["dot_products"]
                    if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
                        shape_info = (
                            " ({})".format(testGen.shapeStr(args_dict["shape"]))
                            if "shape" in args_dict
                            else ""
                        )
                        print(
                            f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
                        )
                        continue
                    # KS and acc_type is required by all dot product generators
                    assert "ks" in args_dict
                    assert "acc_type" in args_dict

                    num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS

                # NOTE(review): for any other dg_type, num_test_sets keeps its
                # value from the previous loop iteration - confirm intended
                if num_test_sets > 0:
                    # One test per set - each gets a copy of the args with "s" set
                    for s in range(0, num_test_sets):
                        new_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
                        new_args_dict = args_dict.copy()
                        new_args_dict["s"] = s
                        new_arg_list.append((new_arg_str, new_args_dict))
                else:
                    # Default is a single test
                    new_arg_list.append((arg_str, args_dict))

        return new_arg_list
1780
1781 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001782 def agNone(testGen, opName, shapeList, dtype, error_name=None):
1783 """A trivial argument generator for operators that don't take any
1784 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001785 arg_list = TosaArgGen._add_data_generators(
1786 testGen,
1787 opName,
1788 dtype,
1789 [("", {})],
1790 error_name,
1791 )
1792 # Return list of tuples: (arg_str, args_dict)
1793 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001794
1795 @staticmethod
Jeremy Johnson30476252023-11-20 16:15:30 +00001796 def agPow(testGen, opName, shapeList, dtype, error_name=None):
1797 """Pow operator needs different test sets to cover random numbers
1798 without creating NaNs or Infs"""
1799 arg_list = TosaArgGen._add_data_generators(
1800 testGen,
1801 opName,
1802 dtype,
1803 [("", {"num_test_sets": 3})],
1804 error_name,
1805 )
1806 # Return list of tuples: (arg_str, args_dict)
1807 return arg_list
1808
1809 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001810 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
1811 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001812 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001813 shape = shapeList[0]
1814
1815 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001816 # Set too small axis
1817 axes = [testGen.rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001818 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001819 # Set too large axis
1820 axes = [testGen.rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001821 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001822 # Create tests for each dimension
1823 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001824
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001825 opid = testGen.TOSA_OP_LIST[opName]["op"]
1826
1827 for a in axes:
1828 args_dict = {"axis": int(a)}
1829 if opid == Op.REDUCE_SUM:
1830 args_dict["dot_products"] = gtu.product(shape)
1831 args_dict["shape"] = shape
1832 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1833 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1834
1835 arg_list.append(("axis{}".format(a), args_dict))
1836
1837 arg_list = TosaArgGen._add_data_generators(
1838 testGen,
1839 opName,
1840 dtype,
1841 arg_list,
1842 error_name,
1843 )
1844 # Return list of tuples: (arg_str, args_dict)
1845 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001846
1847 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001848 def _calculate_sparsity(num_tests, sparsity_factor):
1849 sparsity = num_tests // sparsity_factor + 1
1850 # If there are only a small number of tests, just select them all
1851 if sparsity < 13:
1852 sparsity = 1
1853 # To get a variety of parameter combinations sparsity should not be a
1854 # multiple of 2, 3 or 5
1855 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1856 sparsity += 1
1857 return sparsity
1858
    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        """Build stride/padding/dilation arguments for convolution ops.

        Used by CONV2D, CONV3D and DEPTHWISE_CONV2D. Produces a (sparsely
        sampled) sweep of stride/pad/dilation combinations that give valid,
        integer-exact output shapes (or deliberately invalid ones for
        ERROR_IF tests), plus the compliance info (ks, dot_products, shape)
        for each. Returns a list of (arg_str, args_dict) tuples.
        """
        # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        # Shape: Batches, (Depth), Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: (OFM channels), (KD), KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Op type checks
        conv3d = opName.startswith("conv3d")
        depthwise = opName.startswith("depthwise")

        # Check the rank
        rank = 5 if conv3d else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits channels
        k_rank = rank - 2
        k_pos = 0 if depthwise else 1
        k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
        # compliance size - KS
        k_size = gtu.product(k_shape)
        if not depthwise:
            k_size *= ifm_shape[-1]

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            if error_name == ErrorIf.PadSmallerZero:
                p_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                # Stride must be greater than 1 to force non-integer error
                startStride = (
                    1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
                )
                s_vals = [
                    x for x in range(startStride, testGen.args.max_conv_stride + 1)
                ]
            strides = {x for x in itertools.product(*([s_vals] * k_rank))}
            if error_name == ErrorIf.DilationSmallerOne:
                d_vals = [testGen.rng.choice(range(-5, 1))]
            else:
                d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
            dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[0, bigPadding]] * (k_rank * 2))
                            )
                        }
                    )
                bigStride = 8
                strides.update(
                    {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
                )
                bigDilation = 7
                dilations.update(
                    {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
                )
            # No output dimension limit outside of 8k-level testing
            max_dim_size = None

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(paddings) * len(strides) * len(dilations), sparsity_factor
            )
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            dilation_shape = [1] * k_rank
            pad_shape = [0] * k_rank * 2
            if conv3d:
                # Small stride apart from for big kernel (see below) to keep
                # tensor size/calculation small
                stride_shape = [1] * k_rank
                for idx in range(k_rank):
                    pad_offset = idx * 2
                    if k_shape[idx] == bigKernel:
                        # Padding shape needs to account for tensor shape
                        pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                        pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                        # Big stride to reduce output size
                        stride_shape[idx] = bigKernel
                    else:
                        # Account for kernel size
                        pad_shape[pad_offset] = k_shape[idx] - 1
            else:
                # Always have a large stride with extra padding and dilation to keep
                # tensor calculation reasonable
                stride_shape = [bigKernel] * k_rank
                for idx in range(k_rank):
                    # Dilation shape must account for kernel size
                    dilation_shape[idx] = bigKernel // k_shape[idx]
                    # Padding shape needs to accommodate tensor/kernel & dilation
                    pad_offset = idx * 2
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

            strides = {tuple(stride_shape)}
            dilations = {tuple(dilation_shape)}
            paddings = {tuple(pad_shape)}
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        # Sweep the (sorted, deterministic) combinations applying sparsity
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a positive
                        # sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
                        and (
                            k_rank < 3
                            or (
                                (ifm_shape[3] - 1 + p[4] + p[5])
                                > d[2] * (k_shape[2] - 1)
                            )
                        )
                    ):
                        remainders = []
                        outputs = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            # Spatial extent available to the dilated kernel
                            partial = (
                                ifm_shape[index + 1]
                                - 1
                                + p[pad_offset]
                                + p[pad_offset + 1]
                                - (k_shape[index] - 1) * d[index]
                            )
                            remainders.append(partial % s[index])
                            outputs.append((partial // s[index]) + 1)

                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            if (
                                max_dim_size is not None
                                and max(outputs) >= max_dim_size
                            ):
                                # Test will consume too much memory - skip it
                                continue

                            # Compliance - number of dot product calculations
                            if depthwise:
                                # N*OH*OW*C*M
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, *filter_shape[2:])
                                )
                            else:
                                # N*OH*OW*OC or N*OD*OH*OW*OC
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, filter_shape[0])
                                )
                            args_dict = {
                                "acc_type": accum_dtype,
                                "stride": s,
                                "pad": p,
                                "dilation": d,
                                "kernel": k_shape,
                                "ks": k_size,
                                "dot_products": dots,
                                "shape": ifm_shape,
                            }

                            # Support for larger values than 9 needs different delimiter
                            delim = "" if max(s + p + d) <= 9 else "x"
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        delim.join([str(x) for x in s]),
                                        delim.join([str(x) for x in p]),
                                        delim.join([str(x) for x in d]),
                                    ),
                                    args_dict,
                                )
                            )
                    n += 1

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtypes[0],
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2085
2086 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01002087 def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
2088
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002089 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002090 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01002091
2092 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson1271c442023-09-05 11:39:26 +01002093 accum_dtype = gtu.get_wrong_output_type(opName, testGen.rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01002094 elif error_name == ErrorIf.WrongInputType:
2095 # Pick some potentially correct output dtype if input type is incorrect
2096 accum_dtype = DType.INT32
2097 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +01002098 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01002099
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002100 # Set up compliance info
2101 args_dict = {
2102 "acc_type": accum_dtype,
2103 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
2104 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
2105 "shape": shapeList[0],
2106 }
2107
2108 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
2109
2110 arg_list = TosaArgGen._add_data_generators(
2111 testGen,
2112 opName,
2113 input_dtype,
2114 arg_list,
2115 error_name,
2116 )
2117 # Return list of tuples: (arg_str, args_dict)
2118 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002119
2120 @staticmethod
2121 def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
2122 # Get valid accumulate type(s)
2123 if dtype == DType.INT8:
2124 accum_dtypes = [DType.INT32]
2125 elif dtype == DType.INT16:
2126 accum_dtypes = [DType.INT48]
2127 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002128 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002129 elif dtype == DType.BF16:
2130 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002131 elif dtype == DType.FP32:
2132 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01002133 elif error_name is None:
2134 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
2135
2136 if error_name == ErrorIf.WrongOutputType:
2137 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson1271c442023-09-05 11:39:26 +01002138 accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01002139 elif error_name == ErrorIf.WrongInputType:
2140 # Pick some potentially correct output dtype if input type is incorrect
2141 accum_dtypes = [DType.INT32]
2142
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002143 # Set up compliance info
2144 args_dict = {
2145 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
2146 # Set dot_products = N*H*W
2147 "dot_products": gtu.product(
2148 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
2149 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002150 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002151 }
2152
2153 # Create arg tuple of string and dict
2154 arg_list = []
2155 for a in accum_dtypes:
2156 d = args_dict.copy()
2157 d["acc_type"] = a
2158 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01002159
2160 arg_list = TosaArgGen._add_data_generators(
2161 testGen,
2162 opName,
2163 dtype,
2164 arg_list,
2165 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002166 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002167 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01002168 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002169
2170 @staticmethod
2171 def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002172 arg_list = []
2173
Jeremy Johnson0c716862023-04-13 17:18:19 +01002174 if testGen.args.level8k and error_name is not None:
2175 # Don't produce negative large tests
2176 return arg_list
2177
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002178 ifm_shape = shapeList[0]
2179 filter_shape = shapeList[1]
2180
Jeremy Johnson1271c442023-09-05 11:39:26 +01002181 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01002182
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002183 # Must be rank 4
2184 if error_name != ErrorIf.WrongRank:
2185 assert len(ifm_shape) == 4
2186 assert len(filter_shape) == 4
2187
Jeremy Johnson0c716862023-04-13 17:18:19 +01002188 k_shape = tuple(filter_shape[1:3])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002189
Jeremy Johnson0c716862023-04-13 17:18:19 +01002190 if not testGen.args.level8k:
2191 # Generate comprehensive argument lists
2192 # - except for named errors, which use specific invalid value(s)
2193 smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
2194 if error_name == ErrorIf.PadLargerEqualKernel:
2195 max_filter_size = -max(k_shape[0], k_shape[1])
2196 p_vals = [
2197 testGen.rng.choice(range(max_filter_size - 10, max_filter_size))
2198 ]
2199 else:
2200 p_vals = [
2201 x
2202 for x in range(
2203 smallest_padding_size, testGen.args.max_conv_padding + 1
2204 )
2205 ]
2206 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2207 if error_name == ErrorIf.StrideSmallerOne:
2208 # Can't use stride=0, as it is used to derive output shape, as a divisor
2209 s_vals = [testGen.rng.choice(range(-5, 0))]
2210 else:
2211 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
2212 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002213
Jeremy Johnson0c716862023-04-13 17:18:19 +01002214 if not error_name and testGen.args.oversize:
2215 # add some oversize argument values
2216 if max(ifm_shape) < 64:
2217 bigPadding = 9
2218 paddings.update(
2219 {
2220 x
2221 for x in itertools.product(
2222 *([[smallest_padding_size, bigPadding]] * 4)
2223 )
2224 }
2225 )
2226 bigStride = 8
2227 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
2228
2229 # There are too many parameter combinations, so generate them sparsely,
2230 # very sparse for negative tests
2231 sparsity_factor = 2 if error_name else 10
2232 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
2233 # If there are only a small number of tests, just select them all
2234 if sparsity < 13:
2235 sparsity = 1
2236 # To get a variety of parameter combinations sparsity should not be a
2237 # multiple of 2, 3 or 5
2238 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
2239 sparsity += 1
2240 else:
2241 # Only test 8k levels boundaries
2242 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2243 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2244 bigPadding = bigKernel
2245
2246 pad_shape = [0] * (len(k_shape) * 2)
2247 stride_shape = [1] * len(k_shape)
2248 # The point at which input dimension combined with the stride will
2249 # create large output sizes!
2250 LARGE_SIZE = 2
2251 for idx in range(len(k_shape)):
2252 pad_offset = idx * 2
2253 if k_shape[idx] == bigKernel:
2254 # Set large stride
2255 stride_shape[idx] = bigKernel
2256 # Use negative output padding to reduce shape size
2257 pad_shape[pad_offset] = -(bigPadding - 1)
2258 if ifm_shape[idx + 1] > LARGE_SIZE:
2259 pad_shape[pad_offset + 1] = -(bigPadding - 1)
2260 else:
2261 # The other dimension should be the bigKernel
2262 alt_idx = 1 - idx
2263 if (
2264 k_shape[alt_idx] == bigKernel
2265 and ifm_shape[alt_idx + 1] < LARGE_SIZE
2266 ):
2267 # As the input is small, the large stride won't
2268 # affect the output so we can add some padding
2269 pad_shape[pad_offset + 1] = bigPadding
2270
2271 strides = {tuple(stride_shape)}
2272 paddings = {tuple(pad_shape)}
2273
2274 # Currently allow all combinations that are reasonable size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002275 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002276
2277 n = 0
2278 for s in sorted(list(strides)):
2279 for p in sorted(list(paddings)):
TatWai Chong24594f52022-06-08 00:48:04 -07002280 if n % sparsity == 0:
2281 # Determine the output shape
Jeremy Johnson0c716862023-04-13 17:18:19 +01002282 oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
2283 ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
TatWai Chong24594f52022-06-08 00:48:04 -07002284 os = [ifm_shape[0], oh, ow, filter_shape[0]]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002285
2286 # Support for larger values than 9 needs different delimiter
2287 delim = "" if max(s + p) <= 9 else "x"
TatWai Chong24594f52022-06-08 00:48:04 -07002288 arg_list.append(
2289 (
James Ward8b390432022-08-12 20:48:56 +01002290 "acc{}_st{}_pad{}_os{}".format(
2291 testGen.typeStr(accum_dtype),
Jeremy Johnson0c716862023-04-13 17:18:19 +01002292 delim.join([str(x) for x in s]),
2293 delim.join([str(x) for x in p]),
TatWai Chong24594f52022-06-08 00:48:04 -07002294 "x".join([str(x) for x in os]),
2295 ),
James Ward8b390432022-08-12 20:48:56 +01002296 [accum_dtype, s, p, os],
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002297 )
TatWai Chong24594f52022-06-08 00:48:04 -07002298 )
2299 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002300
2301 return arg_list
2302
    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        """Build the padding and pad-constant arguments for PAD.

        Exhaustively combines small before/after padding values for every
        dimension (sub-sampled for high ranks), choosing a random pad
        constant suited to the dtype. Returns a list of (arg_str, args_dict)
        tuples; an empty list for unsupported dtypes.
        """
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            # Negative-only padding values for the ERROR_IF case
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        # Random pad constant in whichever field (int/fp) matches the dtype
        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            # Unsupported dtype - no tests
            return []

        list_shape_pad_values = list(shape_pad_values)
        # If we are producing tests for rank 6 or greater use sparsity
        if len(list_shape_pad_values) > 1024:
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(list_shape_pad_values), sparsity_factor
            )
        else:
            sparsity = 1

        # Build arg list
        arg_list = []
        for n, paddings in enumerate(list_shape_pad_values):
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                    # Drop the combination if no dimension keeps negative padding
                    if all([p > -1 for p in paddings[i]]):
                        args_valid = False
            if args_valid and n % sparsity == 0:
                # Name encodes the before/after padding of every dimension
                name = "pad"
                for r in range(rank):
                    before, after = paddings[r]
                    name = f"{name}{before}{after}"
                args_dict = {
                    "pad": np.array(paddings),
                    "pad_const_int": pad_const_int,
                    "pad_const_fp": pad_const_fp,
                }
                arg_list.append((name, args_dict))

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2377
    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        """Build the argument list for pooling (avg_pool2d/max_pool2d) tests.

        Generates stride/kernel/padding combinations (plus an accumulator
        type for avg_pool2d), either comprehensively with sparsity applied,
        or as a small set of large values for 8k-level tests.

        Args:
            testGen: test generator providing args and helpers
            opName: operator name string ("max_pool2d" has no accumulator)
            shapeList: list of input shapes; shapeList[0] is NHWC rank 4
            dtype: tensor DType under test
            error_name: optional ErrorIf name to build negative tests for

        Returns:
            List of (arg_str, args_dict) tuples.
        """
        arg_list = []

        shape = shapeList[0]
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        # 8k-level testing only applies to positive tests
        test_level8k = testGen.args.level8k and error_name is None

        startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
        startKernel = 2
        startPad = 0
        if not test_level8k:
            # Generate comprehensive argument lists
            p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * 4))}
            # Stride must be greater than 1 to force non-integer error
            s_vals = [
                x for x in range(startStride, testGen.args.max_pooling_stride + 1)
            ]
            strides = {x for x in itertools.product(*([s_vals] * 2))}
            k_vals = [
                x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
            ]
            kernels = {x for x in itertools.product(*([k_vals] * 2))}
            max_dim_size = None
        else:
            # Only test 8k levels
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            strides = {(1, bigStride), (bigStride, 4)}
            kernels = {(1, bigKernel), (bigKernel, 3)}
            paddings = set()
            # Derive paddings that keep the padded input larger than the kernel
            for s in sorted(list(strides)):
                for k in sorted(list(kernels)):
                    padding = []
                    for idx in range(len(k)):
                        total_padding = s[idx] - shape[idx + 1] + k[idx]
                        while total_padding < 0:
                            # Must meet: shape + padding > kernel
                            total_padding += s[idx]
                        if total_padding < k[idx]:
                            padding.extend([0, total_padding])
                        else:
                            # Note this may produce padding >= k[idx] which is not
                            # allowed - but will be ignored in the creation loop below
                            padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
                    paddings.add(tuple(padding))
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

        # Choose accumulator dtypes appropriate for the input dtype
        if opName == "max_pool2d":
            accum_dtypes = [None]  # max_pool has no accumulate dtype
        elif dtype == DType.INT8 or dtype == DType.INT16:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FP32]
        elif dtype == DType.BF16 or dtype == DType.FP32:
            accum_dtypes = [DType.FP32]
        elif error_name is None:
            assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
        else:
            # Set to something for the ErrorIf case which has
            # incorrect input data-type
            accum_dtypes = [DType.INT32]

        if not test_level8k:
            if testGen.args.oversize:
                # add some oversize argument values
                bigStride = 7
                bigKernel = 9
                strides.update(
                    {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
                )
                kernels.update(
                    {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
                )
                if max(shape) < 64:
                    # padding must be less than the kernel size
                    bigPadding = bigKernel - 1
                    paddings.update(
                        {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
                    )

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 500
            sparsity = (
                len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
            )
        else:
            # We have already limited test output combinations for 8k tests
            sparsity = 1

        # Argument string template; accumulator prefix only for avg_pool2d
        arg_str = (
            "acc{}_st{}_kern{}_pad{}"
            if accum_dtypes[0] is not None
            else "st{}_kern{}_pad{}"
        )

        # NOTE(review): 'shape=[]' is a mutable default argument - safe only
        # while it is never mutated inside this helper (it currently is not)
        def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
            # Return tuple containing the formatted argument string and
            # the corresponding argument values in a dictionary

            # Support for larger values than 9 needs different delimiter
            delim = "" if max(stride + kern + pad) <= 9 else "x"
            arg_str_elems = [
                delim.join([str(x) for x in stride]),
                delim.join([str(x) for x in kern]),
                delim.join([str(x) for x in pad]),
            ]
            args_dict = {
                "stride": stride,
                "pad": pad,
                "kernel": kern,
                "dot_products": dot_products,  # Ignored for error tests
                "shape": shape,
                "ks": gtu.product(kern),  # avg_pool2d: KS = KX*KY
            }

            if accum is not None:
                arg_str_elems.insert(0, testGen.typeStr(accum))
                args_dict["acc_type"] = accum
            return (arg_str.format(*arg_str_elems), args_dict)

        n = 0
        for a in accum_dtypes:
            for s in sorted(list(strides)):
                for p in sorted(list(paddings)):
                    for k in sorted(list(kernels)):
                        if error_name in [
                            ErrorIf.StrideSmallerOne,
                            ErrorIf.KernelSmallerOne,
                            ErrorIf.PadSmallerZero,
                            ErrorIf.PadLargerEqualKernel,
                        ]:
                            sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
                                testGen, error_name, s, p, k
                            )
                            if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                                # NOTE(review): 'shape' is the 5th positional
                                # argument here so it binds to 'dot_products'
                                # (ignored for error tests) and leaves
                                # args_dict["shape"] as [] - confirm intended
                                arg_list.append(
                                    get_arg_list_element(a, sNew, pNew, kNew, shape)
                                )
                        elif (
                            n % sparsity == 0
                            # padding must not exceed the kernel size
                            and p[0] < k[0]
                            and p[1] < k[0]
                            and p[2] < k[1]
                            and p[3] < k[1]
                            # the padded shape must exceed the kernel size
                            and (shape[1] + p[0] + p[1]) > k[0]
                            and (shape[2] + p[2] + p[3]) > k[1]
                        ):
                            partial_h = shape[1] + p[0] + p[1] - k[0]
                            partial_w = shape[2] + p[2] + p[3] - k[1]
                            remainder_h = partial_h % s[0]
                            remainder_w = partial_w % s[1]
                            output_h = partial_h // s[0] + 1
                            output_w = partial_w // s[1] + 1
                            # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
                            if (
                                # the parameters must produce integer exact output
                                error_name != ErrorIf.PoolingOutputShapeNonInteger
                                and remainder_h == 0
                                and remainder_w == 0
                            ) or (
                                error_name == ErrorIf.PoolingOutputShapeNonInteger
                                and (remainder_h != 0 or remainder_w != 0)
                            ):
                                if (
                                    max_dim_size is not None
                                    and max(output_h, output_w) > max_dim_size
                                ):
                                    # Test will consume too much memory - skip it
                                    continue
                                # Dot products = N*OH*OW*C
                                dp = gtu.product(
                                    (shape[0], output_h, output_w, shape[3])
                                )
                                arg_list.append(
                                    get_arg_list_element(a, s, p, k, dp, shape)
                                )
                        n += 1

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2575
    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype, error_name=None):
        """Build the argument list for CAST tests.

        Enumerates the legal (or, for ErrorIf, deliberately illegal) output
        types for the given input type; one argument entry per output type.

        Args:
            testGen: test generator providing helpers
            opName: operator name string
            shapeList: list of input shapes (unused here)
            inDtype: input tensor DType
            error_name: optional ErrorIf name to build negative tests for

        Returns:
            List of (arg_str, args_dict) tuples with "out_type" set.

        Raises:
            Exception: if inDtype is not a supported CAST input type.
        """
        arg_list = []

        # Enumerate the output types here
        if error_name == ErrorIf.WrongOutputType:
            dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
        elif inDtype == DType.INT8:
            dtypeList = [
                DType.BOOL,
                DType.INT16,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT16:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT32:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT16,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FP16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
        elif inDtype == DType.BF16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
        elif inDtype == DType.FP32:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output type for incorrect input type
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(
                ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
            )

        # Now add data generator types
        # NOTE(review): 'dtype' below is the loop variable left over from the
        # enumeration above (i.e. the LAST output type), not 'inDtype' -
        # confirm that this is the intended dtype for data generator selection
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        return arg_list
2639
    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        """Build the argument list for RESCALE tests.

        Enumerates output dtype x scale32 x double_round x per_channel
        combinations, skipping combinations that are illegal for the given
        input dtype or that do not match the requested ErrorIf condition.

        Args:
            testGen: test generator providing helpers
            opName: operator name string
            shapeList: list of input shapes (unused here)
            inDtype: input tensor DType
            error_name: optional ErrorIf name to build negative tests for

        Returns:
            List of (arg_str, [outDtype, scale32, double_round, per_channel])
            tuples.
        """
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                # These output types legitimately allow a non-zero zero point
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                # Not a valid wrong-output-type combination
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition. Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [outDtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list
2738
2739 @staticmethod
2740 def agMul(testGen, opName, shapeList, dtype, error_name=None):
2741 arg_list = []
2742
2743 if dtype is DType.INT32:
2744 for p in range(testGen.args.num_rand_permutations):
2745
2746 shift = testGen.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002747 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002748 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002749 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002750
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002751 arg_list = TosaArgGen._add_data_generators(
2752 testGen,
2753 opName,
2754 dtype,
2755 arg_list,
2756 error_name,
2757 )
2758 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002759 return arg_list
2760
2761 @staticmethod
2762 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
2763 arg_list = []
2764
2765 arg_list.append(("roundTrue", [True]))
2766 arg_list.append(("roundFalse", [False]))
2767
2768 return arg_list
2769
Luke Hutton57287132023-02-06 14:54:18 +00002770 @staticmethod
2771 def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
2772 arg_list = []
2773
2774 arg_list.append(("inverseTrue", [True]))
2775 arg_list.append(("inverseFalse", [False]))
2776
2777 return arg_list
2778
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002779 # Helper function for reshape. Gets some factors of a larger number.
2780 @staticmethod
2781 def getFactors(val, start=1):
2782 factors = []
2783
2784 for i in range(start, int(np.sqrt(val)) + 1):
2785 if (val % i) == 0:
2786 factors.append(i)
2787
2788 return factors
2789
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        """Build the argument list for RESHAPE tests.

        For each requested permutation, picks a random output rank and
        randomly builds a new shape with the same total element count by
        multiplying together factors of that count.

        Args:
            testGen: test generator providing rng, randInt and args
            opName: operator name string
            shapeList: list of input shapes; only shapeList[0] is used
            dtype: tensor DType under test
            error_name: optional ErrorIf name to build negative tests for

        Returns:
            List of (arg_str, args_dict) tuples with "new_shape" set.
        """
        arg_list = []

        origShape = shapeList[0]
        totalElements = gtu.product(origShape)
        factors = TosaArgGen.getFactors(totalElements)

        # Find new shapes up to the number of permutations asked for
        # This code is NOT fast. Fortunately, the numbers are fairly small.
        for p in range(testGen.args.num_rand_permutations):
            # Rank from 1 to TOSA_TENSOR_MAX_RANK
            newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
            if len(factors) < newRank:
                # Not enough factors to fill every dimension - try next perm
                continue

            # escape_counter limits the generation of new shapes to a reasonable time
            for escape_counter in range(100):

                # Generate the new shape of the chosen new rank
                newShape = []
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                # Last dimension takes whatever element count is left so the
                # total is preserved
                newShape.append(remainingElements)

                # Check for duplicates
                duplicate = False
                for name, args_dict in arg_list:
                    if args_dict["new_shape"] == newShape:
                        duplicate = True
                        break

                if not duplicate:
                    outShape = "x".join([str(x) for x in newShape])
                    arg_list.append(
                        (
                            "perm{}_rank{}_out{}".format(p, newRank, outShape),
                            {"new_shape": newShape},
                        )
                    )
                    # Found an output shape for this permutation
                    break

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        return arg_list
2850
    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
        """Build the argument list for TRANSPOSE tests.

        Generates axis permutations for the input rank - all valid ones for
        positive tests, or deliberately out-of-range/duplicated indices for
        the relevant ErrorIf cases - then randomly selects up to
        num_rand_permutations of them.

        Args:
            testGen: test generator providing rng and args
            opName: operator name string
            shapeList: list of input shapes; only shapeList[0] is used
            dtype: tensor DType under test (unused here)
            error_name: optional ErrorIf name to build negative tests for

        Returns:
            List of (arg_str, [perms]) tuples.
        """
        arg_list = []

        ifm_shape = shapeList[0]

        if error_name == ErrorIf.IndexOutsideBounds:
            # Indices entirely above the valid range, and entirely negative
            incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
            incorrect_small_index = range(-len(ifm_shape), 0)
            permutations = [p for p in itertools.permutations(incorrect_large_index)]
            permutations.extend(
                [p for p in itertools.permutations(incorrect_small_index)]
            )
        elif error_name == ErrorIf.IndexUsedTwice:
            # Create list with a duplicated index
            perm_range = list(range(len(ifm_shape)))
            index_choice = testGen.rng.choice(range(len(perm_range)))
            perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
            permutations = [p for p in itertools.permutations(perm_range)]

        else:
            # Get all permutations
            permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Get random permutation generator that uses all permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create list of required amount of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list
2887
    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype, error_name=None):
        """Build the argument list for SLICE tests.

        For each requested permutation, picks a random start and size per
        dimension; dimensions of size 1 always use start=0, size=1.

        Args:
            testGen: test generator providing randInt and args
            opName: operator name string
            shapeList: list of input shapes; only shapeList[0] is used
            dtype: tensor DType under test
            error_name: optional ErrorIf name to build negative tests for

        Returns:
            List of (arg_str, args_dict) tuples with "start" and "size" set.
        """
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            start = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    start.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - start[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    # Dimension too small to slice - take it whole
                    start.append(0)
                    size.append(1)

            if valid:
                # If ERROR_IF test required then incorrect start, size will be returned
                start, size = TosaErrorIfArgGen.eiSliceErrorIf(
                    testGen, error_name, ifm_shape, start, size
                )
                arg_list.append(("perm{}".format(p), {"start": start, "size": size}))
        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2929
2930 @staticmethod
2931 def agTile(testGen, opName, shapeList, dtype, error_name=None):
2932 arg_list = []
2933
2934 ifm_shape = shapeList[0]
2935 rank = len(ifm_shape)
2936
2937 for p in range(testGen.args.num_rand_permutations):
2938
2939 # Pick a few random, but small multiple values
2940 # because otherwise this has a tendency to generate
2941 # enormous tensors
2942 multiples = []
2943 for i in range(rank):
2944 if ifm_shape[i] > 1000:
2945 # Multiple of 1 if ifm_shape dimension is large to reduce
2946 # tensor size
2947 multiples.append(1)
2948 elif max(ifm_shape) > 1000:
2949 multiples.append(2)
2950 else:
2951 multiples.append(testGen.randInt(1, 4))
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00002952 arg_list.append(("perm{}".format(p), {"multiples": multiples}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002953
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00002954 # Now add data generator types
2955 arg_list = TosaArgGen._add_data_generators(
2956 testGen,
2957 opName,
2958 dtype,
2959 arg_list,
2960 error_name,
2961 )
2962 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002963 return arg_list
2964
2965 @staticmethod
2966 def agResize(testGen, opName, shapeList, dtype, error_name=None):
2967 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002968 ifm_shape = shapeList[0]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002969
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002970 def get_aspect_ratio_resize_params():
2971 common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
2972 aspect_ratio = testGen.rng.choice(common_aspect_ratios)
2973 invert = testGen.rng.choice((False, True))
2974 letterbox = testGen.rng.choice((False, True))
2975
2976 scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
2977 scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
2978 scale_y_d = scale_x_d = 1
2979 offset_x = offset_y = 0
2980
2981 if letterbox:
2982 max_border = scale_y_n
2983 border_y = testGen.randInt(low=0, high=max_border)
2984 border_x = 0
2985 else:
2986 # Pillarboxing
2987 border_y = 0
2988 max_border = scale_x_n
2989 border_x = testGen.randInt(low=0, high=max_border)
2990
2991 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
2992 offset = (offset_y, offset_x)
2993 border = (border_y, border_x)
2994
2995 return scale, offset, border
2996
2997 def get_upscale_downscale_params():
2998 valid_params = False
2999 while not valid_params:
3000 upscale = testGen.rng.choice((False, True))
3001
3002 # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
3003 origin_sampling = testGen.rng.choice((False, True))
3004
3005 if upscale:
3006 shift = testGen.randInt(low=1, high=4)
3007 scale_x_d = scale_y_d = 1
3008 scale_x_n = scale_y_n = (
3009 1 << shift if origin_sampling else 2 << shift
3010 )
3011 border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
3012 offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
3013 else:
3014 scale_x_n = 1
3015 scale_y_n = 1
3016
3017 # Return list of valid scale_*_d values (max value 4) given input dim shape
3018 def get_valid_denom(ifm_dim):
3019 return [x for x in range(1, 5) if ifm_dim % x == 1]
3020
3021 # Generate list of valid downscale values and choose one randomly
3022 valid_scale_y_ds = get_valid_denom(ifm_shape[1])
3023 valid_scale_x_ds = get_valid_denom(ifm_shape[2])
3024
3025 if not valid_scale_y_ds and not valid_scale_x_ds:
3026 # Bad parameters, skip
3027 continue
3028
3029 if not valid_scale_y_ds:
3030 scale_y_d = 1
3031 else:
3032 scale_y_d = testGen.rng.choice(valid_scale_y_ds)
3033
3034 if not valid_scale_x_ds:
3035 scale_x_d = 1
3036 else:
3037 scale_x_d = testGen.rng.choice(valid_scale_x_ds)
3038
3039 border_x = border_y = 0
3040 offset_y = testGen.randInt(0, 16 * scale_y_n)
3041 offset_x = testGen.randInt(0, 16 * scale_x_n)
3042 valid_params = True
3043
3044 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3045 offset = (offset_y, offset_x)
3046 border = (border_y, border_x)
3047 return scale, offset, border
3048
3049 def get_rand_params():
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003050 def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
3051 scale = scale_n / scale_d
3052 if scale > max_scale:
3053 factor = scale / max_scale
3054 new_scale_d = math.ceil(scale_d * factor)
3055 assert scale_n / new_scale_d <= max_scale
3056 scale_d = new_scale_d
3057 return scale_d
3058
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003059 # Scale
3060 scale_y_n = testGen.randInt(low=1, high=(1 << 11))
3061 scale_x_n = testGen.randInt(low=1, high=(1 << 11))
3062
3063 scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
3064 scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))
3065
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003066 scale_y_d = fix_scale_to_max_scale(
3067 scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3068 )
3069 scale_x_d = fix_scale_to_max_scale(
3070 scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3071 )
3072
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003073 # Offsets and border within the scale
3074 offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
3075 offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
3076 border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
3077 border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)
3078
3079 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3080 offset = (offset_y, offset_x)
3081 border = (border_y, border_x)
3082 return scale, offset, border
3083
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003084 def get_level_8k_params():
3085 # Create 64x scale - 64/1 to 2048/32
3086 scale_d = testGen.randInt(
3087 low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
3088 )
3089 scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
3090 # Create half to fifth scaling
3091 scale_d_alt = testGen.randInt(low=2, high=6)
3092 scale_n_alt = 1
3093 switch = testGen.rng.choice((False, True))
3094 if switch:
3095 scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
3096 else:
3097 scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)
3098
3099 offset_y = testGen.rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
3100 offset_x = testGen.rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
3101 offset = (offset_y, offset_x)
3102 border_y = testGen.rng.choice((-16 * scale[0], 0, scale[0] - 1))
3103 border_x = testGen.rng.choice((-16 * scale[2], 0, scale[2] - 1))
3104 border = (border_y, border_x)
3105 return scale, offset, border
3106
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003107 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003108 # Exclude illegal {mode, type} configurations. Pick legal output types
3109 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
3110 outputDTypeList = [DType.INT8]
3111 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
3112 outputDTypeList = [DType.INT16]
3113 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
3114 outputDTypeList = [DType.INT32]
3115 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
3116 outputDTypeList = [DType.INT48]
James Ward8b390432022-08-12 20:48:56 +01003117 elif dtype == DType.FP16:
3118 outputDTypeList = [DType.FP16]
James Ward24dbc422022-10-19 12:20:31 +01003119 elif dtype == DType.BF16:
3120 outputDTypeList = [DType.BF16]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003121 elif dtype == DType.FP32:
3122 outputDTypeList = [DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003123 elif error_name == ErrorIf.WrongInputType:
3124 # If an incorrect input type is used then we set a 'correct'
3125 # output type to avoid other errors
3126 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
3127 else:
3128 continue
3129
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003130 arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"
3131
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003132 for outputDType in outputDTypeList:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003133 perm = 0
3134 while perm < testGen.args.num_rand_permutations:
3135 # Random choice of type of params we are testing
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003136 if not testGen.args.level8k:
3137 _rnd_param_fn = testGen.rng.choice(
3138 (
3139 get_rand_params,
3140 get_upscale_downscale_params,
3141 get_aspect_ratio_resize_params,
3142 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003143 )
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003144 scale, offset, border = _rnd_param_fn()
3145 else:
3146 scale, offset, border = get_level_8k_params()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003147
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003148 # Expand params for bounds-checking
3149 (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
3150 (offset_y, offset_x) = offset
3151 (border_y, border_x) = border
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003152
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003153 # Make sure output dimensions OH and OW are integers
3154 partial_output_y = (
3155 (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
3156 )
3157 partial_output_x = (
3158 (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
3159 )
3160 if error_name == ErrorIf.ResizeOutputShapeNonInteger:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003161 # Look for non-integer test
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003162 if (
3163 partial_output_y % scale_y_d == 0
3164 and partial_output_x % scale_x_d == 0
3165 ):
3166 # Skip this test as it doesn't produce NonInteger output
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003167 if perm > 0:
3168 perm += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003169 continue
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003170 else:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003171 # Alter the scaling factors to make the output integer
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003172 while partial_output_y % scale_y_d != 0:
3173 scale_y_d -= 1
3174 while partial_output_x % scale_x_d != 0:
3175 scale_x_d -= 1
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003176 # Make sure we are still within max scaling
3177 if (
3178 scale_y_n / scale_y_d
3179 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
3180 scale_x_n / scale_x_d
3181 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
3182 # Skip the test as it is using too large a scaling factor
3183 if perm > 0:
3184 perm += 1
3185 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003186
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003187 output_y = partial_output_y // scale_y_d + 1
3188 output_x = partial_output_x // scale_x_d + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003189
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003190 if (
3191 output_y >= testGen.args.max_resize_output_dim
3192 or output_x >= testGen.args.max_resize_output_dim
3193 ) and error_name is None:
3194 # Skip positive test if output dim will be too high
3195 # Avoid high test latency and OOM issues
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003196 if not testGen.args.level8k or perm > 0:
3197 perm += 1
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003198 continue
3199
3200 if (
3201 output_y <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003202 or output_y >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003203 or output_x <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003204 or output_x >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003205 ):
3206 # Output dimensions out of scope
3207 if error_name is not None and perm > 0:
3208 # As long as we have one ERROR_IF test, don't worry
3209 # about creating all the other permutations
3210 perm += 1
3211 continue
3212
3213 if error_name == ErrorIf.ResizeOutputShapeMismatch and (
3214 (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003215 output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003216 and output_y - scale_y_d < 1
3217 )
3218 or (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003219 output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003220 and output_x - scale_x_d < 1
3221 )
3222 ):
3223 # Can't create a negative test with these params as it
3224 # will create invalid output size
3225 if perm > 0:
3226 perm += 1
3227 continue
3228
3229 scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
3230 offset = [offset_y, offset_x]
3231 border = [border_y, border_x]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003232
3233 # Common for all data types
3234 if error_name is not None:
3235 (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003236 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003237 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003238 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003239 outputDTypeNew,
3240 ) = TosaErrorIfArgGen.eiResizeErrorIf(
3241 testGen,
3242 error_name,
3243 mode,
3244 dtype,
3245 shapeList,
3246 outputDType,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003247 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003248 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003249 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003250 )
3251 else:
3252 outputDTypeNew = outputDType
3253
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003254 arg_to_append = (
3255 arg_str.format(
3256 "N" if mode == ResizeMode.NEAREST else "B",
3257 testGen.typeStr(outputDTypeNew),
3258 scale[0],
3259 scale[1],
3260 scale[2],
3261 scale[3],
3262 offset[0],
3263 offset[1],
3264 border[0],
3265 border[1],
3266 ),
3267 [
3268 mode,
3269 scale,
3270 offset,
3271 border,
3272 dtype,
3273 outputDTypeNew,
3274 ],
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003275 )
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003276 if arg_to_append in arg_list:
3277 # Skip already generated test params
3278 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003279
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003280 # Valid permutation
3281 perm += 1
3282 arg_list.append(arg_to_append)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003283 return arg_list
3284
3285 @staticmethod
3286 def agTable(testGen, opName, shapeList, dtype, error_name=None):
3287 arg_list = []
3288
3289 if dtype == DType.INT8:
3290 table = np.int32(
3291 testGen.rng.integers(low=-128, high=128, size=[256])
3292 ).tolist()
3293 else: # INT16
3294 table = np.int32(
3295 testGen.rng.integers(low=-32768, high=32768, size=[513])
3296 ).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003297 # Make sure all slopes are within REQUIRE min/max 16-bit int
3298 for idx in range(len(table) - 1):
3299 slope = table[idx + 1] - table[idx]
3300 # Alter the next table entry to force the slope to be ok
3301 if slope > 32767:
3302 table[idx + 1] -= slope - 32767
3303 if slope < -32768:
3304 table[idx + 1] -= slope + 32768
3305 slope = table[idx + 1] - table[idx]
3306 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003307 arg_list.append(
3308 (
3309 "",
3310 [table],
3311 )
3312 )
3313 return arg_list
3314
3315 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
3316 # CondIf generates the condition values here.
3317 # Convert to tensors in the build function, along with the
3318 # then and else blocks
3319 arg_list = []
3320
3321 for c in [False, True]:
3322 arg_list.append(("cond{}".format(int(c)), [c]))
3323
3324 return arg_list
3325
3326 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
3327 # While loop: 0 iterations, 1, more than 1
3328 arg_list = []
3329
3330 for iter in [0, 1, 4]:
3331 arg_list.append(("iter{}".format(iter), [iter]))
3332
3333 return arg_list