blob: 48639565e1c057c871777a2862a582afafc19e7b [file] [log] [blame]
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
4import math
James Ward8b390432022-08-12 20:48:56 +01005import warnings
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01006
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
16# DTypeNames, DType, Op and ResizeMode are convenience variables to the
17# flatc-generated types that should be enums, but aren't
18
19
class TosaQuantGen:
    """Helpers that create random quantization info (zero points / scales).

    Referenced via the 'qgen': entry in an operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(testGen, dtype, error_name=None):
        """Pick a zero point appropriate for dtype.

        A user-supplied --zeropoint is clamped to the dtype's range; otherwise
        a random value is chosen.  For zero-point ERROR_IF tests a non-zero
        value is forced.  Non-quantized dtypes always get 0.
        """
        if dtype == DType.INT8:
            # Signed 8-bit range: honour any user-supplied zero point, clamped.
            zp = testGen.args.zeropoint
            if zp is not None:
                return min(127, max(-128, zp))
            return testGen.randInt(-128, 128)
        if dtype == DType.UINT8:
            # Unsigned 8-bit range.
            zp = testGen.args.zeropoint
            if zp is not None:
                return min(255, max(0, zp))
            return testGen.randInt(0, 256)
        if error_name in (
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ):
            # Negative test - must not produce a zero value.
            zp = testGen.randInt(-128, 128)
            return zp if zp != 0 else 1
        # All other dtypes are not quantized.
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Return [input_zp, output_zp] for a unary operator.

        Only the slot matching the requested zero-point error gets the
        error_name passed through; the other slot gets a valid zero point.
        """
        in_err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        out_err = error_name if error_name == ErrorIf.OutputZeroPointNotZero else None
        return [
            TosaQuantGen.getZeroPoint(testGen, dtype, in_err),
            TosaQuantGen.getZeroPoint(testGen, dtype, out_err),
        ]

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Return [input_zp, weight_zp] for a convolution-style operator."""
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        in_err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        wt_err = error_name if error_name == ErrorIf.WeightZeroPointNotZero else None
        return [
            TosaQuantGen.getZeroPoint(testGen, dtypeList[0], in_err),
            TosaQuantGen.getZeroPoint(testGen, dtypeList[1], wt_err),
        ]

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Return [a_zp, b_zp] for MATMUL.

        For the InputZeroPointNotZero test both operands get faulty values.
        """
        err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        return [
            TosaQuantGen.getZeroPoint(testGen, dtype, err),
            TosaQuantGen.getZeroPoint(testGen, dtype, err),
        ]

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale to a fixed-point multiplier + shift.

        Derived from computeMultiplierAndShiftTosaScale32.  scale32 selects
        31-bit (True) or 15-bit (False) multiplier precision.  Returns a
        (multiplier, shift) tuple with shift constrained to [2, 62].
        """
        scaleBits = 31 if scale32 else 15

        mantissa, exponent = math.frexp(scaleFp)
        if scaleFp < 0.0:
            # Work with the magnitude; frexp keeps the sign on the mantissa.
            mantissa = -mantissa

        multiplier = round(mantissa * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        # Rounding the mantissa (range [0.5, 1.0)) can push it up to 1.0 -
        # renormalize so the multiplier stays within scaleBits bits.
        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            exponent = exponent + 1

        shift = scaleBits - exponent

        # Adjust multiplier so that the shift lands in the allowed range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert 2 <= shift <= 62

        return multiplier, shift
152
153
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Create identical random shapes for every operand of a basic op."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            # (the first input keeps the original shape; later inputs pick up the
            # altered rank on the next iteration)
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Create identical rank-4 NHWC shapes, one per operand."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgGather(testGen, opName, rank, error_name=None):
        """Create [values (N,K,C), indices (N,W)] shapes for GATHER."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_shape = testGen.makeShape(rank)
        values_shape = testGen.constrictBatchSize(values_shape)

        N = values_shape[0]
        W = testGen.makeDimension()
        indices_shape = [N, W]

        shape_list = [values_shape, indices_shape]
        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Create [values_in (N,K,C), indices (N,W), input (N,W,C)] for SCATTER."""
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = testGen.randInt(W_min, W_max) if W_min < W_max else W_min

        input_shape = [N, W, C]

        shape_list = []
        shape_list.append(values_in_shape)
        shape_list.append([N, W])  # indices
        shape_list.append(input_shape)

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Create shapes where one randomly chosen input is made broadcastable."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        fuzz_idx = testGen.randInt(0, rank)

        for i in range(pl + const):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (OHWI), bias (OC)] shapes for CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Create [ifm (NDHWC), filter (ODHWI), bias (OC)] shapes for CONV3D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension()

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (OHWI), bias (OC)] shapes for TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (HWCM), bias (C*M)] shapes for DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        """Create two matching NHW shapes (real & imaginary inputs) for FFT2D.

        H and W are forced to powers of two except for the
        KernelNotPowerOfTwo negative test.
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # Increment by 2 if current size is 1 (1+1=2 is a power of two)
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        """Create a single NHW shape for RFFT2D with power-of-two H and W."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Create [input (N,IC), filter (OC,IC), bias (OC)] for FULLY_CONNECTED."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Create [a (N,H,C), b (N,C,W)] shapes for MATMUL."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Create a random number of (normally matching) shapes for CONCAT."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                # Negative test: randomly grow or shrink the rank of later inputs
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Split the first CONCAT shape along axis across the const inputs.

        Keeps the concatenated result the same size while avoiding many
        large constant tensors.  Returns shapeList unchanged for error tests
        that need the original shapes, or when the axis cannot be split.
        """
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList

    @staticmethod
    def tgShape(testGen, opName, rank, error_name=None):
        """Create 1-D shapes of length rank for shape operators."""
        pl, const = opName["operands"]
        shape = [rank]

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100647
class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each tensor in each test."""

    def __init__(self):
        pass

    class TVGInfo:
        """Enhanced tensor values information including data gen dict."""

        def __init__(self, tensorList, dataGenDict):
            # Serializer tensors created for the test
            self.tensorList = tensorList
            # Data generation meta-data dictionary (None when internal
            # data generation was used instead)
            self.dataGenDict = dataGenDict
660
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100661 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000662 def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100663 pCount, cCount = op["operands"]
664
665 tens = []
666 tens.extend(
667 testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
668 )
669 tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
670
671 return tens
672
    # Default high value for random numbers
    # (the largest finite value of each FP format: (2 - 2^-mantissa_bits) * 2^max_exp)
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
    }

    # Default lowest normal values for random numbers
    # (the smallest positive normal value, 2^min_normal_exp, of each FP format)
    TVG_FLOAT_LOW_VALUE = {
        DType.FP32: np.exp2(-126),
        DType.FP16: np.exp2(-14),
        DType.BF16: np.exp2(-126),
    }
686
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100687 @staticmethod
Jeremy Johnson30476252023-11-20 16:15:30 +0000688 def _get_data_range(testGen, dtype, highValueLookup, lowValueLookup=None):
689 # Return a tuple of (low,high) data range values for the given data
690 # type using a combination of per operator table limits, data limits
691 # and user supplied ranges for FP numbers
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000692 if dtype in highValueLookup:
Jeremy Johnson30476252023-11-20 16:15:30 +0000693 type_range = testGen.getDTypeRange(dtype, high_inclusive=True)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000694 high_val = highValueLookup[dtype]
Jeremy Johnson30476252023-11-20 16:15:30 +0000695 if lowValueLookup is not None and dtype in lowValueLookup:
696 low_val = lowValueLookup[dtype]
697 else:
698 low_val = -high_val
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000699 # Set the values to something that won't produce infinity whilst
Jeremy Johnson30476252023-11-20 16:15:30 +0000700 # respecting the default ranges if more/less than the low/high
701 # values
702 data_range = (
703 max(low_val, type_range[0]),
704 min(high_val, type_range[1]),
705 )
706 if data_range[0] > data_range[1]:
707 # Invalid data range from low to high created due to user
708 # constraints revert to using internal ranges as they are
709 # known to work
710 msg = f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
711 warnings.warn(msg)
712 data_range = (low_val, high_val)
713 return data_range
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000714 return None
715
716 @staticmethod
Jeremy Johnson1271c442023-09-05 11:39:26 +0100717 def tvgLazyGenDefault(
718 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
719 ):
720 # Variable inputs versus constants
721 pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
Jeremy Johnson3eafe662024-01-10 13:13:35 +0000722 if "p_count" in argsDict:
723 # Override for operators like CONCAT
724 pCount = argsDict["p_count"]
725 cCount = argsDict["c_count"]
726 assert pCount + cCount == len(
727 shapeList
728 ), "Placeholders & Constant tensors must match shapes list"
729
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000730 tens_ser_list = []
Jeremy Johnson1271c442023-09-05 11:39:26 +0100731
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100732 if (
733 error_name is not None
734 or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100735 or "data_gen" not in testGen.TOSA_OP_LIST[opName]
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100736 ):
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000737 # Fall back to internal data gen when dealing with unsupported types or ops
738 data_range = argsDict["data_range"] if "data_range" in argsDict else None
739 for idx, info in enumerate(zip(shapeList, dtypeList)):
Jeremy Johnson30476252023-11-20 16:15:30 +0000740 roundMode = False
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000741 shape, dtype = info
Jeremy Johnson30476252023-11-20 16:15:30 +0000742 if "data_range_list" in argsDict:
743 data_range = argsDict["data_range_list"][idx]["range"]
744 roundMode = (
745 "round" in argsDict["data_range_list"][idx]
746 and argsDict["data_range_list"][idx]["round"] is True
747 )
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000748 if data_range is not None and dtype not in (
749 DType.FP16,
750 DType.FP32,
751 DType.BF16,
752 ):
753 # Change from inclusive to exclusive range
754 data_range = (data_range[0], data_range[1] + 1)
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000755 # Ignore lazy data gen option and create data array using any range limits
Won Jeon64e4bfe2024-01-18 06:31:55 +0000756
757 if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
758 arr = np.int64(argsDict["fixed_data"][idx])
759 else:
760 arr = testGen.getRandTensor(shape, dtype, data_range)
Jeremy Johnson30476252023-11-20 16:15:30 +0000761 if roundMode:
762 arr = np.round(arr)
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000763 if idx < pCount:
764 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
765 else:
766 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100767
Jeremy Johnson1271c442023-09-05 11:39:26 +0100768 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
769
770 # Create data generator meta-data
771 dg_type = argsDict["dg_type"]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100772 tens_data = {
773 "version": "0.1",
774 "tensors": {},
775 }
776 dg_tens_meta = tens_data["tensors"]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100777 for idx, shape in enumerate(shapeList):
778
779 tens_meta = {}
Won Jeon64e4bfe2024-01-18 06:31:55 +0000780 if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
781 tens_meta["generator"] = gtu.DataGenType(
782 gtu.DataGenType.FIXED_DATA
783 ).name
784 else:
785 tens_meta["generator"] = gtu.DataGenType(dg_type).name
786
Jeremy Johnson1271c442023-09-05 11:39:26 +0100787 tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
788 tens_meta["shape"] = [int(i) for i in shape]
789 tens_meta["input_pos"] = idx
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100790 tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()
Jeremy Johnson1271c442023-09-05 11:39:26 +0100791
792 if idx < pCount:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100793 tens_meta["input_type"] = "VARIABLE"
Jeremy Johnson1271c442023-09-05 11:39:26 +0100794 else:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100795 tens_meta["input_type"] = "CONSTANT"
Jeremy Johnson1271c442023-09-05 11:39:26 +0100796
797 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
798 info = {}
Won Jeon64e4bfe2024-01-18 06:31:55 +0000799 if (
800 tens_meta["generator"]
801 == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name
802 ):
803 info["data"] = [int(i) for i in argsDict["fixed_data"][idx]]
804 tens_meta["fixed_data_info"] = info
805 else:
806 # TODO - generate seed for this generator based on test
807 info["rng_seed"] = 42
Jeremy Johnson30476252023-11-20 16:15:30 +0000808
Won Jeon64e4bfe2024-01-18 06:31:55 +0000809 data_range = None
810 if "data_range_list" in argsDict:
811 data_range = argsDict["data_range_list"][idx]["range"]
812 if "round" in argsDict["data_range_list"][idx]:
813 info["round"] = argsDict["data_range_list"][idx]["round"]
814 elif "data_range" in argsDict:
815 data_range = argsDict["data_range"]
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000816
Won Jeon64e4bfe2024-01-18 06:31:55 +0000817 if data_range is None:
818 data_range = testGen.getDTypeRange(
819 dtypeList[idx], high_inclusive=True
820 )
821 info["range"] = [str(v) for v in data_range]
822 tens_meta["pseudo_random_info"] = info
Jeremy Johnson1271c442023-09-05 11:39:26 +0100823 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
824 info = {}
825 info["s"] = argsDict["s"]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100826 info["ks"] = int(argsDict["ks"])
827 if "acc_type" in argsDict:
828 # Convert type number into JSON name
829 info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
830 "json"
831 ]
832 if "kernel" in argsDict:
833 info["kernel"] = [int(k) for k in argsDict["kernel"]]
834 if "axis" in argsDict:
835 info["axis"] = int(argsDict["axis"])
Jeremy Johnson1271c442023-09-05 11:39:26 +0100836 tens_meta["dot_product_info"] = info
837 else:
838 # TODO - other data gen type
839 assert False, "TODO: support other data gen types"
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100840
841 # Using the finished generate config meta data - generate the data if
842 # needed and assign a tensor name from the serializer
843
844 # Need to generate data when not lazy or for the bias tensor as we need
845 # to work out if the bias data is non-zero for compliance
846 if not testGen.args.lazy_data_gen or (
847 idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
848 ):
849 # Give this tensor a temporary name until we get one from the serializer
850 temp_name = f"placeholder_{idx}"
851 dg_tens_meta[temp_name] = tens_meta
852 # Create data now using the temporary name to access meta details
853 data = testGen.dgl.get_tensor_data(temp_name, tens_data)
Won Jeon64e4bfe2024-01-18 06:31:55 +0000854 if tens_meta["data_type"] == "SHAPE":
855 # Tensor type SHAPE and Numpy file type must be the same
856 data = np.int64(data)
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100857 # Remove the item as we will give it the correct name later
858 del dg_tens_meta[temp_name]
859
860 if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
861 # The KS value used by compliance verification is altered when the
862 # bias data is non-zero
863 if max(abs(data)) > 0.0:
864 argsDict["ksb"] = argsDict["ks"] + 1
865
866 if testGen.args.lazy_data_gen:
867 data = None
868
869 if tens_meta["input_type"] == "VARIABLE":
870 tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
871 else:
872 tens = testGen.ser.addConst(shape, dtypeList[idx], data)
873
874 tens_ser_list.append(tens)
875 # Add the meta data to the list using the serializer tensor name
Jeremy Johnson1271c442023-09-05 11:39:26 +0100876 dg_tens_meta[tens.name] = tens_meta
877
Jeremy Johnson1271c442023-09-05 11:39:26 +0100878 return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
879
880 @staticmethod
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000881 def tvgNegate(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100882 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000883 # Integer test
884 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100885 pCount, cCount = op["operands"]
886 assert (
887 pCount == 1 and cCount == 0
888 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100889 # Must create tensors with values within accumulator (int32) negatable
890 # range
891 max_val = (1 << 31) - 1
892 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100893 arr = np.int32(
894 testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
895 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000896 tens_ser_list = []
897 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100898 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
899 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000900 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100901 else:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000902 # ERROR_IF or floating point test
903 return TosaTensorValuesGen.tvgLazyGenDefault(
904 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100905 )
906
    # Set the ADD/SUB data range to half the largest value to avoid infinities
    # (the sum or difference of two in-range operands can reach twice the
    # operand magnitude, so halving the limit keeps results finite)
    TVG_FLOAT_HIGH_VALUE_ADDSUB = {
        DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
        DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
        DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
    }
913
    @staticmethod
    def tvgAddSub(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create input values for ADD/SUB (including the SHAPE variants).

        Integer tests build data here so the second operand can be adjusted to
        guarantee the result never saturates int32; other tests fall through
        to the lazy default generator with a halved (ADDSUB) data range.
        """
        if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
            # Make sure the integer operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            tens_ser_list = []
            add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
            # NOTE(review): the tensor-shape range is used for the data values
            # of both INT32 and SHAPE inputs here - presumably intentional to
            # keep values small; confirm against the SHAPE operator change
            data_range = testGen.args.tensor_shape_range
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0], data_range)
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1], data_range)
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        # amin keeps the most-clipped value so every broadcast
                        # position stays within range
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            data_range = TosaTensorValuesGen._get_data_range(
                testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
            )
            if data_range:
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
991
992 @staticmethod
993 def tvgCondIfWhileLoop(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000994 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100995 ):
996 if dtypeList[0] in (
997 DType.INT32,
998 DType.INT16,
999 DType.INT8,
1000 ):
1001 # Limit input tensors with cond_if_binary or while_loop to stop
1002 # saturation of add/sub ops with int32 and keep all logical shift
1003 # values between 0 to 31 for int16 or int8
1004 pCount, cCount = op["operands"]
1005 pRemain = pCount
1006 placeholders = []
1007 for idx, shape in enumerate(shapeList[:]):
1008 if dtypeList[0] == DType.INT32:
1009 arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
1010 else:
1011 arr = np.int32(
1012 testGen.rng.integers(low=0, high=32, size=shapeList[idx])
1013 )
1014 if pRemain > 0:
1015 placeholders.append(
1016 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
1017 )
1018 pRemain -= 1
1019 else:
1020 placeholders.append(
1021 testGen.ser.addConst(shape, dtypeList[idx], arr)
1022 )
1023
1024 return placeholders
1025 else:
1026 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001027 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001028 )
1029
1030 @staticmethod
1031 def tvgArithmeticRightShift(
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001032 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001033 ):
1034 pCount, cCount = op["operands"]
1035 # Force value of operand[1] to be within [0, num_bits]
1036 assert (
1037 pCount == 2 and cCount == 0
1038 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
1039
1040 placeholders = []
1041 for idx, shape in enumerate(shapeList[:]):
1042 if idx == 1:
1043 if dtypeList[idx] == DType.INT8:
1044 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
1045 elif dtypeList[idx] == DType.INT16:
1046 arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
1047 elif dtypeList[idx] == DType.INT32:
1048 arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
1049 elif error_name == ErrorIf.WrongInputType:
1050 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
1051 else:
1052 raise Exception("OpArithmeticRightShift: invalid input dtype")
1053 else:
1054 arr = testGen.getRandTensor(shape, dtypeList[idx])
1055 placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
1056
1057 return placeholders
1058
1059 @staticmethod
Won Jeon64e4bfe2024-01-18 06:31:55 +00001060 def tvgReshape(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
1061 dtypeList[1] = DType.SHAPE
1062 shapeList[1] = [len(argsDict["new_shape"])]
1063 # Create a new list for the pre-generated data in argsDict["fixed_data"]
1064 argsDict["fixed_data"] = [None, argsDict["new_shape"]]
1065
1066 return TosaTensorValuesGen.tvgLazyGenDefault(
1067 testGen, op, dtypeList, shapeList, argsDict, error_name
1068 )
1069
1070 @staticmethod
1071 def tvgTile(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
1072 dtypeList[1] = DType.SHAPE
1073 shapeList[1] = [len(argsDict["multiples"])]
1074 argsDict["fixed_data"] = [None, argsDict["multiples"]]
1075
1076 return TosaTensorValuesGen.tvgLazyGenDefault(
1077 testGen, op, dtypeList, shapeList, argsDict, error_name
1078 )
1079
1080 @staticmethod
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00001081 def tvgSelect(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001082 # Set datatype of condition tensor to boolean
1083 dtypeList[0] = DType.BOOL
1084
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00001085 return TosaTensorValuesGen.tvgLazyGenDefault(
1086 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001087 )
1088
1089 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001090 def tvgIntDiv(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001091 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001092 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001093 pCount, cCount = op["operands"]
1094 assert (
1095 pCount == 2 and cCount == 0
1096 ), "Op.INTDIV must have 2 placeholders, 0 consts"
1097
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001098 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001099
1100 # Two invalid cases for Op.INTDIV:
1101 # 1. divisor == 0
1102 # 2. dividend == -(1<<31) and divisor == -1
1103 while True:
1104 dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1105 divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
1106
1107 if (divisor_arr == 0).any():
1108 continue
1109
1110 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
1111 continue
1112
1113 break
1114
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001115 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001116 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1117 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001118 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001119 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1120 )
1121
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001122 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001123 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001124 return TosaTensorValuesGen.tvgLazyGenDefault(
1125 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001126 )
1127
    # Set the MUL data range to the square root of the largest value
    # to avoid infinities
    # (the product of two values each bounded by sqrt(max) stays <= max)
    TVG_FLOAT_HIGH_VALUE_MUL = {
        DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
1135
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001136 @staticmethod
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001137 def tvgMul(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1138 if error_name is not None or dtypeList[0] in (
1139 DType.FP16,
1140 DType.BF16,
1141 DType.FP32,
1142 ):
1143 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001144 data_range = TosaTensorValuesGen._get_data_range(
1145 testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
1146 )
1147 if data_range:
1148 argsDict["data_range"] = data_range
1149
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001150 return TosaTensorValuesGen.tvgLazyGenDefault(
1151 testGen, opName, dtypeList, shapeList, argsDict, error_name
1152 )
1153 else:
1154 # Integer test
1155 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001156 pCount, cCount = op["operands"]
1157 assert (
1158 pCount == 2 and cCount == 0
1159 ), "Op.MUL must have 2 placeholders, 0 consts"
1160
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001161 tens_ser_list = []
1162
1163 # Make sure multiply result in int32 range
Won Jeon74342e52024-01-09 00:34:40 +00001164 if dtypeList[0] == DType.SHAPE:
1165 shift = 0
1166 else:
1167 shift = argsDict["shift"]
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001168 if dtypeList[0] == DType.INT8:
1169 num_bits = 8
1170 elif dtypeList[0] == DType.INT16:
1171 num_bits = 16
Won Jeon74342e52024-01-09 00:34:40 +00001172 elif dtypeList[0] in (DType.INT32, DType.SHAPE):
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001173 num_bits = 32
1174 elif error_name == ErrorIf.WrongInputType:
1175 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001176 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001177 raise Exception("OpMul: invalid input dtype")
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001178
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001179 for idx, shape in enumerate(shapeList[:]):
Won Jeon74342e52024-01-09 00:34:40 +00001180 if dtypeList[idx] == DType.SHAPE:
1181 low = testGen.args.tensor_shape_range[0]
1182 high = testGen.args.tensor_shape_range[1]
1183 else:
1184 low = -(2 ** (num_bits - 1))
1185 high = (2 ** (num_bits - 1)) - 1
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001186
1187 a_arr = np.int32(
1188 testGen.rng.integers(low=low, high=high, size=shapeList[0])
1189 )
1190 b_arr = np.int32(
1191 testGen.rng.integers(low=low, high=high, size=shapeList[1])
1192 )
1193
1194 i = 0
1195 while True:
1196
1197 a_arr_64 = a_arr.astype(np.int64)
1198 b_arr_64 = b_arr.astype(np.int64)
1199
1200 if shift > 0:
1201 rounding = 1 << (shift - 1)
1202 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001203 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001204 result_arr = a_arr_64 * b_arr_64
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001205
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001206 if (result_arr > -(2**31)).all() and (
1207 result_arr <= ((2**31) - 1)
1208 ).all():
1209 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001210
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001211 i = i + 1
1212 a_arr = a_arr // 2
1213 b_arr = b_arr // 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001214
Won Jeon74342e52024-01-09 00:34:40 +00001215 if dtypeList[0] == DType.SHAPE:
1216 tens_ser_list.append(
1217 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
1218 )
1219 tens_ser_list.append(
1220 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
1221 )
1222 else:
1223 tens_ser_list.append(
1224 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1225 )
1226 tens_ser_list.append(
1227 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1228 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001229
1230 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001231
1232 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001233 def tvgConcat(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001234 count = len(shapeList) - testGen.args.num_const_inputs_concat
1235 if count < 1:
1236 count = 1
1237 if testGen.args.num_const_inputs_concat == 0:
1238 count = len(shapeList)
1239
Won Jeon74342e52024-01-09 00:34:40 +00001240 op = testGen.TOSA_OP_LIST[opName]
1241 if op["op"] == Op.CONCAT_SHAPE:
1242 # Set the axis to 0
1243 shapeList = TosaTensorGen.tgConcatConstInput(
1244 testGen, shapeList, 0, error_name
1245 )
1246 else:
1247 shapeList = TosaTensorGen.tgConcatConstInput(
1248 testGen, shapeList, argsDict["axis"], error_name
1249 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001250
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001251 # Override default pCount/cCount for operator
1252 argsDict["p_count"] = count
1253 argsDict["c_count"] = len(shapeList) - count
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001254
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001255 return TosaTensorValuesGen.tvgLazyGenDefault(
1256 testGen, opName, dtypeList, shapeList, argsDict, error_name
1257 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001258
1259 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001260 def tvgLogicalShift(
1261 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1262 ):
1263 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001264 pCount, cCount = op["operands"]
1265 assert (
1266 pCount == 2 and cCount == 0
1267 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
1268 values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1269 shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001270 tens_ser_list = []
1271 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001272 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1273 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001274 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001275 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1276 )
1277
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001278 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001279
1280 @staticmethod
Jeremy Johnsona0150012023-11-15 15:52:06 +00001281 def tvgEqual(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1282 if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
1283 # Integer
1284 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001285 pCount, cCount = op["operands"]
1286 assert (
1287 pCount == 2 and cCount == 0
1288 ), "Op.EQUAL must have 2 placeholders, 0 consts"
Jeremy Johnsona0150012023-11-15 15:52:06 +00001289
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001290 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1291 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
Jeremy Johnsona0150012023-11-15 15:52:06 +00001292
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001293 # Using random numbers means that it will be very unlikely that
1294 # there are any matching (equal) values, therefore force that
1295 # there are twice the number of matching values as the tensor rank
1296 for num in range(0, len(shapeList[0]) * 2):
1297 a_index = []
1298 b_index = []
1299 # Choose an index in each axis for the whole shape
1300 for axis in range(0, len(shapeList[0])):
1301 # Index can be up to the largest dimension in both shapes
1302 index = np.int32(
1303 testGen.rng.integers(
1304 0, max(shapeList[0][axis], shapeList[1][axis])
1305 )
1306 )
1307 # Reduce the index down to a shape's dim for broadcasting
1308 a_index.append(min(shapeList[0][axis] - 1, index))
1309 b_index.append(min(shapeList[1][axis] - 1, index))
1310
1311 a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
1312
Jeremy Johnsona0150012023-11-15 15:52:06 +00001313 tens_ser_list = []
1314 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001315 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1316 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001317 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001318 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1319 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001320 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001321 else:
Jeremy Johnsona0150012023-11-15 15:52:06 +00001322 # ERROR_IF or floating point test
1323 return TosaTensorValuesGen.tvgLazyGenDefault(
1324 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001325 )
1326
1327 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001328 def tvgReduceSum(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001329 dtype = dtypeList[0]
1330 if dtype == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001331 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001332 pCount, cCount = op["operands"]
1333 assert (
1334 pCount == 1 and cCount == 0
1335 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
1336 # Limit values so that the sum cannot exceed the range of an int32 during
1337 # summation of any axis
1338 range_val = int((1 << 31) / max(shapeList[0]))
1339 values_arr = np.int32(
1340 testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
1341 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001342 tens_ser_list = []
1343 tens_ser_list.append(
Jeremy Johnson30476252023-11-20 16:15:30 +00001344 testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001345 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001346 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001347 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001348 # ERROR_IF or dot product floating point test
Jeremy Johnson30476252023-11-20 16:15:30 +00001349 if (
1350 error_name is None
1351 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
1352 ):
1353 # Limit ranges for (non error & non compliance) tests by using
1354 # values that can be summed on any axis to not hit infinity
1355 highval_lookup = {
1356 dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
1357 / max(shapeList[0])
1358 }
1359 data_range = TosaTensorValuesGen._get_data_range(
1360 testGen, dtype, highval_lookup
1361 )
1362 assert data_range is not None
1363 argsDict["data_range"] = data_range
1364
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001365 return TosaTensorValuesGen.tvgLazyGenDefault(
1366 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001367 )
1368
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001369 @staticmethod
1370 def tvgReduceProduct(
1371 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1372 ):
1373 dtype = dtypeList[0]
1374 if error_name is None:
1375 # Limit ranges for (non error) tests by using
1376 # values that can be multiplied on any axis to not hit infinity
1377 highval_lookup = {
1378 dtype: math.pow(
1379 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
1380 1 / max(shapeList[0]),
1381 )
1382 }
1383 data_range = TosaTensorValuesGen._get_data_range(
1384 testGen, dtype, highval_lookup
1385 )
1386 assert data_range is not None
1387 argsDict["data_range"] = data_range
1388
1389 return TosaTensorValuesGen.tvgLazyGenDefault(
1390 testGen, opName, dtypeList, shapeList, argsDict, error_name
1391 )
1392
    # Set the POW exponent high data range
    TVG_FLOAT_HIGH_VALUE_POW_EXP = {
        DType.FP32: 10.0,
        DType.FP16: 10.0,
        DType.BF16: 10.0,
    }
    # POW highest base value (within a safe margin of error) that can be raised
    # to +ve exponent that doesn't become Infinity
    # i.e. floor(max_value ** (1 / max_exponent)) per data type
    TVG_FLOAT_HIGH_VALUE_POW_BASE = {
        DType.FP32: math.floor(
            math.pow(
                TVG_FLOAT_HIGH_VALUE[DType.FP32],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
            )
        ),
        DType.FP16: math.floor(
            math.pow(
                TVG_FLOAT_HIGH_VALUE[DType.FP16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
            )
        ),
        DType.BF16: math.floor(
            math.pow(
                TVG_FLOAT_HIGH_VALUE[DType.BF16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
            )
        ),
    }
    # POW lowest base value (within a safe margin of error) that can be raised
    # to -ve exponent that doesn't become Infinity
    # (the ceil(... * 1000) / 1000 rounds the bound up to 3 decimal places)
    TVG_FLOAT_LOW_VALUE_POW_BASE = {
        DType.FP32: math.ceil(
            math.pow(
                1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
            )
            * 1000
        )
        / 1000,
        DType.FP16: math.ceil(
            math.pow(
                1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
            )
            * 1000
        )
        / 1000,
        DType.BF16: math.ceil(
            math.pow(
                1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
            )
            * 1000
        )
        / 1000,
    }
1449
1450 @staticmethod
1451 def tvgPow(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1452 if error_name is not None:
1453 return TosaTensorValuesGen.tvgLazyGenDefault(
1454 testGen, opName, dtypeList, shapeList, argsDict, error_name
1455 )
1456 dtype = dtypeList[0]
1457 # Different ranges for POW
1458 test_set = argsDict["s"]
1459 if test_set == 0:
1460 # Positive base with fractional exponent
1461 base_range = TosaTensorValuesGen._get_data_range(
1462 testGen,
1463 dtype,
1464 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1465 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1466 )
1467 exp_range = TosaTensorValuesGen._get_data_range(
1468 testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
1469 )
1470 exp_round = False
1471 else:
1472 # Integer exponent
1473 exp_range = TosaTensorValuesGen._get_data_range(
1474 testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
1475 )
1476 exp_round = True
1477 if test_set == 1:
1478 # Positive base
1479 base_range = TosaTensorValuesGen._get_data_range(
1480 testGen,
1481 dtype,
1482 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1483 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1484 )
1485 else:
1486 assert test_set == 2
1487 # Negative base
1488 # Supply new look up tables with negative values
1489 base_range = TosaTensorValuesGen._get_data_range(
1490 testGen,
1491 dtype,
1492 {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
1493 {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
1494 )
1495
1496 data_range_list = (
1497 {
1498 "range": base_range,
1499 },
1500 {
1501 "range": exp_range,
1502 "round": exp_round,
1503 },
1504 )
1505 argsDict["data_range_list"] = data_range_list
1506 return TosaTensorValuesGen.tvgLazyGenDefault(
1507 testGen, opName, dtypeList, shapeList, argsDict, error_name
1508 )
1509
1510 @staticmethod
1511 def tvgLogRsqrt(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1512 # LOG & RSQRT data range from lowest expressible positive number to
1513 # largest to avoid NaNs
1514 data_range = TosaTensorValuesGen._get_data_range(
1515 testGen,
1516 dtypeList[0],
1517 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1518 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1519 )
1520 if data_range:
1521 argsDict["data_range"] = data_range
1522
1523 return TosaTensorValuesGen.tvgLazyGenDefault(
1524 testGen, opName, dtypeList, shapeList, argsDict, error_name
1525 )
1526
    # Set the EXP data range to the log of the largest to smallest values
    # to avoid infinities or making the result zero
    TVG_FLOAT_HIGH_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
    # Lower bound: log of the smallest expressible positive value, so that
    # exp(input) does not flush to zero
    TVG_FLOAT_LOW_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
    }
1539
1540 @staticmethod
1541 def tvgExp(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1542 data_range = TosaTensorValuesGen._get_data_range(
1543 testGen,
1544 dtypeList[0],
1545 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1546 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1547 )
1548 if data_range:
1549 argsDict["data_range"] = data_range
1550
1551 return TosaTensorValuesGen.tvgLazyGenDefault(
1552 testGen, opName, dtypeList, shapeList, argsDict, error_name
1553 )
1554
1555 @staticmethod
1556 def tvgFullyConnected(
1557 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1558 ):
1559 dtype = dtypeList[0]
1560 if (
1561 error_name is None
1562 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
Jeremy Johnson718f3472023-11-30 14:18:19 +00001563 and dtype in (DType.BF16,)
Jeremy Johnson30476252023-11-20 16:15:30 +00001564 ):
Jeremy Johnson718f3472023-11-30 14:18:19 +00001565 # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
Jeremy Johnson30476252023-11-20 16:15:30 +00001566 # Limit ranges for (non error & non compliance) FP tests by using
1567 # values that can be multiplied on any axis to not hit infinity/NaN
1568 IC = shapeList[0][1]
1569 highval_lookup = {
1570 dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
1571 }
1572 data_range = TosaTensorValuesGen._get_data_range(
1573 testGen, dtype, highval_lookup
1574 )
1575 assert data_range is not None
1576 argsDict["data_range"] = data_range
1577
1578 return TosaTensorValuesGen.tvgLazyGenDefault(
1579 testGen, opName, dtypeList, shapeList, argsDict, error_name
1580 )
1581
Jeremy Johnson708da822023-11-15 16:25:45 +00001582 @staticmethod
1583 def tvgCast(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1584 in_dtype = dtypeList[0]
1585 out_dtype = argsDict["out_type"]
1586 # Create look up to limit input tensor to output type maximums to avoid
1587 # FP infinities and saturation of integers
1588 out_range = testGen.getDTypeRange(out_dtype, high_inclusive=True)
1589 highval_lookup = {in_dtype: out_range[1]}
1590 data_range = TosaTensorValuesGen._get_data_range(
1591 testGen,
1592 in_dtype,
1593 highval_lookup,
1594 )
1595
1596 assert data_range is not None
1597 argsDict["data_range"] = data_range
1598
1599 return TosaTensorValuesGen.tvgLazyGenDefault(
1600 testGen, opName, dtypeList, shapeList, argsDict, error_name
1601 )
1602
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001603 @staticmethod
1604 def tvgGather(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1605 K = shapeList[0][1]
1606
1607 # Fix the type of the indices tensor
1608 dtypeList[1] = DType.INT32
1609
1610 dtype = dtypeList[0]
1611 if not gtu.dtypeIsSupportedByCompliance(dtype):
1612 # Test unsupported by data generator
1613 op = testGen.TOSA_OP_LIST[opName]
1614 pCount, cCount = op["operands"]
1615 assert (
1616 pCount == 2 and cCount == 0
1617 ), "Op.GATHER must have 2 placeholders, 0 consts"
1618
1619 tens_ser_list = []
1620 for idx, shape in enumerate(shapeList):
1621 dtype = dtypeList[idx]
1622 if idx != 1:
1623 arr = testGen.getRandTensor(shape, dtype)
1624 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1625 else:
1626 # Limit data range of indices tensor upto K (exclusive)
1627 arr = testGen.getRandTensor(shape, dtype, (0, K))
1628 # To match old functionality - create indices as CONST
1629 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1630
1631 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1632
1633 else:
1634 # ERROR_IF or floating point test
1635 # Use inclusive values upto index K for indices tensor
1636 data_range_list = (
1637 {"range": None},
1638 {"range": (0, K - 1)},
1639 )
1640 argsDict["data_range_list"] = data_range_list
1641
1642 return TosaTensorValuesGen.tvgLazyGenDefault(
1643 testGen, opName, dtypeList, shapeList, argsDict, error_name
1644 )
1645
    @staticmethod
    def tvgScatter(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Tensor values generator for SCATTER.

        Forces the indices tensor (input 1) to INT32 and - for the direct
        generation path - builds indices from a truncated permutation of
        [0, K) per batch so no output index is repeated, as the spec
        requires.  K is dimension 1 of values_in, W dimension 1 of input.
        """
        K = shapeList[0][1]
        W = shapeList[2][1]

        # Work out an indices tensor here with data that doesn't exceed the
        # dimension K of the values_in tensor and does NOT repeat the same K
        # location as needed by the spec:
        # "It is not permitted to repeat the same output index within a single
        # SCATTER operation and so each output index occurs at most once."
        assert K >= W, "Op.SCATTER W must be smaller or equal to K"

        # Fix the type of the indices tensor
        dtypeList[1] = DType.INT32

        dtype = dtypeList[0]
        if not gtu.dtypeIsSupportedByCompliance(dtype):
            # Test unsupported by data generator
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 3 and cCount == 0
            ), "Op.SCATTER must have 3 placeholders, 0 consts"

            tens_ser_list = []
            for idx, shape in enumerate(shapeList):
                dtype = dtypeList[idx]
                if idx != 1:
                    arr = testGen.getRandTensor(shape, dtype)
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    # Create the indices array
                    assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
                    arr = []
                    for n in range(shape[0]):
                        # Get a shuffled list of output indices (0 to K-1) and
                        # limit length to W
                        arr.append(testGen.rng.permutation(K)[:W])
                    indices_arr = np.array(arr, dtype=np.int32)  # (N, W)
                    # To match old functionality - create indices as CONST
                    tens_ser_list.append(
                        testGen.ser.addConst(shape, dtype, indices_arr)
                    )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        else:
            # ERROR_IF or floating point test
            # Use inclusive values upto index K for indices tensor
            # NOTE(review): this path does not enforce index uniqueness -
            # presumably acceptable for ERROR_IF/FP tests; confirm
            data_range_list = (
                {"range": None},
                {"range": (0, K - 1)},
                {"range": None},
            )
            argsDict["data_range_list"] = data_range_list

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
1705
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001706
1707class TosaArgGen:
1708 """Argument generators create exhaustive or random lists of attributes for
1709 operators that take attributes or other parameters.
1710
1711 The return value is a list of (descriptive_name, [arglist]) tuples where
1712 the descriptive_name is appended to the test name and the arglist is expanded
1713 as arguments to the operator build function.
1714 """
1715
    def __init__(self):
        # Stateless - all argument generators are static methods
        pass
1718
1719 @staticmethod
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001720 def _add_data_generators(testGen, opName, dtype, arg_list, error_name):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001721 """Add extra tests for each type of data generator for this op."""
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001722 if (
1723 error_name is None
1724 and "data_gen" in testGen.TOSA_OP_LIST[opName]
1725 and gtu.dtypeIsSupportedByCompliance(dtype)
1726 ):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001727 if dtype in [DType.FP16, DType.FP32, DType.BF16]:
1728 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
1729 else:
1730 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
1731 else:
1732 # Error test or No data generator types listed - assume random
1733 dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
1734
1735 # Expand arg list with other data generator types
1736 new_arg_list = []
1737 for dg_type in dataGenTypesList:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001738 for arg_str, args_dict in arg_list:
1739 args_dict["dg_type"] = dg_type
Jeremy Johnson1271c442023-09-05 11:39:26 +01001740 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
Jeremy Johnson30476252023-11-20 16:15:30 +00001741 if error_name is None:
1742 num_test_sets = (
1743 args_dict["num_test_sets"]
1744 if "num_test_sets" in args_dict
1745 else 0
1746 )
1747 else:
1748 num_test_sets = 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01001749
1750 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
1751 # Extra tests for each dot product test set
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001752 dot_products = args_dict["dot_products"]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001753 if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001754 shape_info = (
1755 " ({})".format(testGen.shapeStr(args_dict["shape"]))
1756 if "shape" in args_dict
1757 else ""
1758 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01001759 print(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001760 f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
Jeremy Johnson1271c442023-09-05 11:39:26 +01001761 )
1762 continue
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001763 # KS and acc_type is required by all dot product generators
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001764 assert "ks" in args_dict
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001765 assert "acc_type" in args_dict
Jeremy Johnson1271c442023-09-05 11:39:26 +01001766
Jeremy Johnson30476252023-11-20 16:15:30 +00001767 num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS
1768
1769 if num_test_sets > 0:
1770 for s in range(0, num_test_sets):
1771 new_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001772 new_args_dict = args_dict.copy()
1773 new_args_dict["s"] = s
1774 new_arg_list.append((new_arg_str, new_args_dict))
Jeremy Johnson30476252023-11-20 16:15:30 +00001775 else:
1776 # Default is a single test
1777 new_arg_list.append((arg_str, args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001778
1779 return new_arg_list
1780
1781 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001782 def agNone(testGen, opName, shapeList, dtype, error_name=None):
1783 """A trivial argument generator for operators that don't take any
1784 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001785 arg_list = TosaArgGen._add_data_generators(
1786 testGen,
1787 opName,
1788 dtype,
1789 [("", {})],
1790 error_name,
1791 )
1792 # Return list of tuples: (arg_str, args_dict)
1793 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001794
1795 @staticmethod
Jeremy Johnson30476252023-11-20 16:15:30 +00001796 def agPow(testGen, opName, shapeList, dtype, error_name=None):
1797 """Pow operator needs different test sets to cover random numbers
1798 without creating NaNs or Infs"""
1799 arg_list = TosaArgGen._add_data_generators(
1800 testGen,
1801 opName,
1802 dtype,
1803 [("", {"num_test_sets": 3})],
1804 error_name,
1805 )
1806 # Return list of tuples: (arg_str, args_dict)
1807 return arg_list
1808
1809 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001810 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
1811 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001812 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001813 shape = shapeList[0]
1814
1815 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001816 # Set too small axis
1817 axes = [testGen.rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001818 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001819 # Set too large axis
1820 axes = [testGen.rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001821 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001822 # Create tests for each dimension
1823 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001824
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001825 opid = testGen.TOSA_OP_LIST[opName]["op"]
1826
1827 for a in axes:
1828 args_dict = {"axis": int(a)}
1829 if opid == Op.REDUCE_SUM:
1830 args_dict["dot_products"] = gtu.product(shape)
1831 args_dict["shape"] = shape
1832 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1833 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1834
1835 arg_list.append(("axis{}".format(a), args_dict))
1836
1837 arg_list = TosaArgGen._add_data_generators(
1838 testGen,
1839 opName,
1840 dtype,
1841 arg_list,
1842 error_name,
1843 )
1844 # Return list of tuples: (arg_str, args_dict)
1845 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001846
1847 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001848 def _calculate_sparsity(num_tests, sparsity_factor):
1849 sparsity = num_tests // sparsity_factor + 1
1850 # If there are only a small number of tests, just select them all
1851 if sparsity < 13:
1852 sparsity = 1
1853 # To get a variety of parameter combinations sparsity should not be a
1854 # multiple of 2, 3 or 5
1855 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1856 sparsity += 1
1857 return sparsity
1858
    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        """Build stride/pad/dilation argument combinations for convolutions.

        Used by CONV2D, CONV3D and DEPTHWISE_CONV2D.  Combinations are
        generated comprehensively (sparsely sampled) for normal tests, or
        as single boundary cases for level 8k tests; only combinations
        giving a positive-sized output are kept, and the output must be
        integer exact except for the ConvOutputShapeNonInteger error test.
        """
        # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        # Shape: Batches, (Depth), Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: (OFM channels), (KD), KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Op type checks
        conv3d = opName.startswith("conv3d")
        depthwise = opName.startswith("depthwise")

        # Check the rank
        rank = 5 if conv3d else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits channels
        k_rank = rank - 2
        k_pos = 0 if depthwise else 1
        k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
        # compliance size - KS
        k_size = gtu.product(k_shape)
        if not depthwise:
            k_size *= ifm_shape[-1]

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            if error_name == ErrorIf.PadSmallerZero:
                p_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                # Stride must be greater than 1 to force non-integer error
                startStride = (
                    1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
                )
                s_vals = [
                    x for x in range(startStride, testGen.args.max_conv_stride + 1)
                ]
            strides = {x for x in itertools.product(*([s_vals] * k_rank))}
            if error_name == ErrorIf.DilationSmallerOne:
                d_vals = [testGen.rng.choice(range(-5, 1))]
            else:
                d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
            dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[0, bigPadding]] * (k_rank * 2))
                            )
                        }
                    )
                bigStride = 8
                strides.update(
                    {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
                )
                bigDilation = 7
                dilations.update(
                    {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
                )
            # No output dimension size limit for normal level tests
            max_dim_size = None

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(paddings) * len(strides) * len(dilations), sparsity_factor
            )
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            dilation_shape = [1] * k_rank
            pad_shape = [0] * k_rank * 2
            if conv3d:
                # Small stride apart from for big kernel (see below) to keep
                # tensor size/calculation small
                stride_shape = [1] * k_rank
                for idx in range(k_rank):
                    pad_offset = idx * 2
                    if k_shape[idx] == bigKernel:
                        # Padding shape needs to account for tensor shape
                        pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                        pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                        # Big stride to reduce output size
                        stride_shape[idx] = bigKernel
                    else:
                        # Account for kernel size
                        pad_shape[pad_offset] = k_shape[idx] - 1
            else:
                # Always have a large stride with extra padding and dilation to keep
                # tensor calculation reasonable
                stride_shape = [bigKernel] * k_rank
                for idx in range(k_rank):
                    # Dilation shape must account for kernel size
                    dilation_shape[idx] = bigKernel // k_shape[idx]
                    # Padding shape needs to accommodate tensor/kernel & dilation
                    pad_offset = idx * 2
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

            strides = {tuple(stride_shape)}
            dilations = {tuple(dilation_shape)}
            paddings = {tuple(pad_shape)}
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        # Sweep every stride/padding/dilation combination, keeping every
        # sparsity-th one that produces a valid output shape
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a positive
                        # sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
                        and (
                            k_rank < 3
                            or (
                                (ifm_shape[3] - 1 + p[4] + p[5])
                                > d[2] * (k_shape[2] - 1)
                            )
                        )
                    ):
                        # Per-axis output size and integer-exactness remainder
                        remainders = []
                        outputs = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            partial = (
                                ifm_shape[index + 1]
                                - 1
                                + p[pad_offset]
                                + p[pad_offset + 1]
                                - (k_shape[index] - 1) * d[index]
                            )
                            remainders.append(partial % s[index])
                            outputs.append((partial // s[index]) + 1)

                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            if (
                                max_dim_size is not None
                                and max(outputs) >= max_dim_size
                            ):
                                # Test will consume too much memory - skip it
                                continue

                            # Compliance - number of dot product calculations
                            if depthwise:
                                # TODO - add support
                                dots = 0
                            else:
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, filter_shape[0])
                                )
                            args_dict = {
                                "acc_type": accum_dtype,
                                "stride": s,
                                "pad": p,
                                "dilation": d,
                                "kernel": k_shape,
                                "ks": k_size,
                                "dot_products": dots,
                                "shape": ifm_shape,
                            }

                            # Support for larger values than 9 needs different delimiter
                            delim = "" if max(s + p + d) <= 9 else "x"
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        delim.join([str(x) for x in s]),
                                        delim.join([str(x) for x in p]),
                                        delim.join([str(x) for x in d]),
                                    ),
                                    args_dict,
                                )
                            )
                    n += 1

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtypes[0],
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2082
2083 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01002084 def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
2085
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002086 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002087 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01002088
2089 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson1271c442023-09-05 11:39:26 +01002090 accum_dtype = gtu.get_wrong_output_type(opName, testGen.rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01002091 elif error_name == ErrorIf.WrongInputType:
2092 # Pick some potentially correct output dtype if input type is incorrect
2093 accum_dtype = DType.INT32
2094 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +01002095 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01002096
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002097 # Set up compliance info
2098 args_dict = {
2099 "acc_type": accum_dtype,
2100 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
2101 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
2102 "shape": shapeList[0],
2103 }
2104
2105 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
2106
2107 arg_list = TosaArgGen._add_data_generators(
2108 testGen,
2109 opName,
2110 input_dtype,
2111 arg_list,
2112 error_name,
2113 )
2114 # Return list of tuples: (arg_str, args_dict)
2115 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002116
2117 @staticmethod
2118 def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
2119 # Get valid accumulate type(s)
2120 if dtype == DType.INT8:
2121 accum_dtypes = [DType.INT32]
2122 elif dtype == DType.INT16:
2123 accum_dtypes = [DType.INT48]
2124 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002125 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002126 elif dtype == DType.BF16:
2127 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002128 elif dtype == DType.FP32:
2129 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01002130 elif error_name is None:
2131 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
2132
2133 if error_name == ErrorIf.WrongOutputType:
2134 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson1271c442023-09-05 11:39:26 +01002135 accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01002136 elif error_name == ErrorIf.WrongInputType:
2137 # Pick some potentially correct output dtype if input type is incorrect
2138 accum_dtypes = [DType.INT32]
2139
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002140 # Set up compliance info
2141 args_dict = {
2142 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
2143 # Set dot_products = N*H*W
2144 "dot_products": gtu.product(
2145 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
2146 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002147 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002148 }
2149
2150 # Create arg tuple of string and dict
2151 arg_list = []
2152 for a in accum_dtypes:
2153 d = args_dict.copy()
2154 d["acc_type"] = a
2155 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01002156
2157 arg_list = TosaArgGen._add_data_generators(
2158 testGen,
2159 opName,
2160 dtype,
2161 arg_list,
2162 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002163 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002164 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01002165 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002166
    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
        """Build stride/pad/output-shape argument combinations for
        TRANSPOSE_CONV2D.

        NOTE(review): unlike agConv, this returns the older list-style
        arguments [accum_dtype, stride, pad, output_shape] and does not
        expand via _add_data_generators.
        """
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        # Shape: Batches, Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: OFM channels, KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Must be rank 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == 4
            assert len(filter_shape) == 4

        k_shape = tuple(filter_shape[1:3])

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
            if error_name == ErrorIf.PadLargerEqualKernel:
                max_filter_size = -max(k_shape[0], k_shape[1])
                p_vals = [
                    testGen.rng.choice(range(max_filter_size - 10, max_filter_size))
                ]
            else:
                p_vals = [
                    x
                    for x in range(
                        smallest_padding_size, testGen.args.max_conv_padding + 1
                    )
                ]
            paddings = {x for x in itertools.product(*([p_vals] * 4))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
            strides = {x for x in itertools.product(*([s_vals] * 2))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[smallest_padding_size, bigPadding]] * 4)
                            )
                        }
                    )
                bigStride = 8
                strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 10
            sparsity = len(paddings) * len(strides) // sparsity_factor + 1
            # If there are only a small number of tests, just select them all
            if sparsity < 13:
                sparsity = 1
            # To get a variety of parameter combinations sparsity should not be a
            # multiple of 2, 3 or 5
            while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
                sparsity += 1
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            pad_shape = [0] * (len(k_shape) * 2)
            stride_shape = [1] * len(k_shape)
            # The point at which input dimension combined with the stride will
            # create large output sizes!
            LARGE_SIZE = 2
            for idx in range(len(k_shape)):
                pad_offset = idx * 2
                if k_shape[idx] == bigKernel:
                    # Set large stride
                    stride_shape[idx] = bigKernel
                    # Use negative output padding to reduce shape size
                    pad_shape[pad_offset] = -(bigPadding - 1)
                    if ifm_shape[idx + 1] > LARGE_SIZE:
                        pad_shape[pad_offset + 1] = -(bigPadding - 1)
                else:
                    # The other dimension should be the bigKernel
                    alt_idx = 1 - idx
                    if (
                        k_shape[alt_idx] == bigKernel
                        and ifm_shape[alt_idx + 1] < LARGE_SIZE
                    ):
                        # As the input is small, the large stride won't
                        # affect the output so we can add some padding
                        pad_shape[pad_offset + 1] = bigPadding

            strides = {tuple(stride_shape)}
            paddings = {tuple(pad_shape)}

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        # Sweep stride/padding combinations, keeping every sparsity-th one
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                if n % sparsity == 0:
                    # Determine the output shape
                    oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
                    ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]

                    # Support for larger values than 9 needs different delimiter
                    delim = "" if max(s + p) <= 9 else "x"
                    arg_list.append(
                        (
                            "acc{}_st{}_pad{}_os{}".format(
                                testGen.typeStr(accum_dtype),
                                delim.join([str(x) for x in s]),
                                delim.join([str(x) for x in p]),
                                "x".join([str(x) for x in os]),
                            ),
                            [accum_dtype, s, p, os],
                        )
                    )
                n += 1

        return arg_list
2299
    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        """Build padding arguments for the PAD operator.

        Returns an empty list for unsupported dtypes.
        """
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        # Pick a pad constant appropriate to the dtype
        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            return []

        list_shape_pad_values = list(shape_pad_values)
        # If we are producing tests for rank 6 or greater use sparsity
        if len(list_shape_pad_values) > 1024:
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(list_shape_pad_values), sparsity_factor
            )
        else:
            sparsity = 1

        # Build arg list
        arg_list = []
        for n, paddings in enumerate(list_shape_pad_values):
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                    # Invalid if any dimension ends up with no negative pad
                    if all([p > -1 for p in paddings[i]]):
                        args_valid = False
            if args_valid and n % sparsity == 0:
                name = "pad"
                for r in range(rank):
                    before, after = paddings[r]
                    name = f"{name}{before}{after}"
                args_dict = {
                    "pad": np.array(paddings),
                    "pad_const_int": pad_const_int,
                    "pad_const_fp": pad_const_fp,
                }
                arg_list.append((name, args_dict))

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list
2374
2375 @staticmethod
2376 def agPooling(testGen, opName, shapeList, dtype, error_name=None):
2377 arg_list = []
2378
2379 shape = shapeList[0]
2380 if error_name != ErrorIf.WrongRank:
2381 assert len(shape) == 4
2382
Jeremy Johnson0c716862023-04-13 17:18:19 +01002383 test_level8k = testGen.args.level8k and error_name is None
2384
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002385 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002386 startKernel = 2
2387 startPad = 0
2388 if not test_level8k:
2389 # Generate comprehensive argument lists
2390 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2391 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2392 # Stride must be greater than 1 to force non-integer error
2393 s_vals = [
2394 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2395 ]
2396 strides = {x for x in itertools.product(*([s_vals] * 2))}
2397 k_vals = [
2398 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2399 ]
2400 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2401 max_dim_size = None
2402 else:
2403 # Only test 8k levels
2404 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2405 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2406 strides = {(1, bigStride), (bigStride, 4)}
2407 kernels = {(1, bigKernel), (bigKernel, 3)}
2408 paddings = set()
2409 for s in sorted(list(strides)):
2410 for k in sorted(list(kernels)):
2411 padding = []
2412 for idx in range(len(k)):
2413 total_padding = s[idx] - shape[idx + 1] + k[idx]
2414 while total_padding < 0:
2415 # Must meet: shape + padding > kernel
2416 total_padding += s[idx]
2417 if total_padding < k[idx]:
2418 padding.extend([0, total_padding])
2419 else:
2420 # Note this may produce padding >= k[idx] which is not
2421 # allowed - but will be ignored in the creation loop below
2422 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2423 paddings.add(tuple(padding))
2424 # Create a limit for the output dimensions size
2425 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002426
James Ward8b390432022-08-12 20:48:56 +01002427 if opName == "max_pool2d":
2428 accum_dtypes = [None] # max_pool has no accumulate dtype
2429 elif dtype == DType.INT8 or dtype == DType.INT16:
2430 accum_dtypes = [DType.INT32]
2431 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002432 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002433 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002434 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01002435 elif error_name is None:
2436 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2437 else:
2438 # Set to something for the ErrorIf case which has
2439 # incorrect input data-type
2440 accum_dtypes = [DType.INT32]
2441
Jeremy Johnson0c716862023-04-13 17:18:19 +01002442 if not test_level8k:
2443 if testGen.args.oversize:
2444 # add some oversize argument values
2445 bigStride = 7
2446 bigKernel = 9
2447 strides.update(
2448 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002449 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002450 kernels.update(
2451 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2452 )
2453 if max(shape) < 64:
2454 # padding must be less than the kernel size
2455 bigPadding = bigKernel - 1
2456 paddings.update(
2457 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2458 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002459
Jeremy Johnson0c716862023-04-13 17:18:19 +01002460 # There are too many parameter combinations, so generate them sparsely,
2461 # very sparse for negative tests
2462 sparsity_factor = 2 if error_name else 500
2463 sparsity = (
2464 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2465 )
2466 else:
2467 # We have already limited test output combinations for 8k tests
2468 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002469
James Ward8b390432022-08-12 20:48:56 +01002470 arg_str = (
2471 "acc{}_st{}_kern{}_pad{}"
2472 if accum_dtypes[0] is not None
2473 else "st{}_kern{}_pad{}"
2474 )
2475
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002476 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01002477 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002478 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01002479
2480 # Support for larger values than 9 needs different delimiter
2481 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01002482 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01002483 delim.join([str(x) for x in stride]),
2484 delim.join([str(x) for x in kern]),
2485 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01002486 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002487 args_dict = {
2488 "stride": stride,
2489 "pad": pad,
2490 "kernel": kern,
2491 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002492 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002493 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
2494 }
James Ward8b390432022-08-12 20:48:56 +01002495
2496 if accum is not None:
2497 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002498 args_dict["acc_type"] = accum
2499 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002500
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002501 n = 0
James Ward8b390432022-08-12 20:48:56 +01002502 for a in accum_dtypes:
2503 for s in sorted(list(strides)):
2504 for p in sorted(list(paddings)):
2505 for k in sorted(list(kernels)):
2506 if error_name in [
2507 ErrorIf.StrideSmallerOne,
2508 ErrorIf.KernelSmallerOne,
2509 ErrorIf.PadSmallerZero,
2510 ErrorIf.PadLargerEqualKernel,
2511 ]:
2512 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
2513 testGen, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002514 )
James Ward8b390432022-08-12 20:48:56 +01002515 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002516 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002517 get_arg_list_element(a, sNew, pNew, kNew, shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002518 )
James Ward8b390432022-08-12 20:48:56 +01002519 elif (
2520 n % sparsity == 0
2521 # padding must not exceed the kernel size
2522 and p[0] < k[0]
2523 and p[1] < k[0]
2524 and p[2] < k[1]
2525 and p[3] < k[1]
2526 # the padded shape must exceed the kernel size
2527 and (shape[1] + p[0] + p[1]) > k[0]
2528 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002529 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002530 partial_h = shape[1] + p[0] + p[1] - k[0]
2531 partial_w = shape[2] + p[2] + p[3] - k[1]
2532 remainder_h = partial_h % s[0]
2533 remainder_w = partial_w % s[1]
2534 output_h = partial_h // s[0] + 1
2535 output_w = partial_w // s[1] + 1
2536 # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
James Ward8b390432022-08-12 20:48:56 +01002537 if (
2538 # the parameters must produce integer exact output
2539 error_name != ErrorIf.PoolingOutputShapeNonInteger
2540 and remainder_h == 0
2541 and remainder_w == 0
2542 ) or (
2543 error_name == ErrorIf.PoolingOutputShapeNonInteger
2544 and (remainder_h != 0 or remainder_w != 0)
2545 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002546 if (
2547 max_dim_size is not None
2548 and max(output_h, output_w) > max_dim_size
2549 ):
2550 # Test will consume too much memory - skip it
2551 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002552 # Dot products = N*OH*OW*C
2553 dp = gtu.product(
2554 (shape[0], output_h, output_w, shape[3])
2555 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002556 arg_list.append(
2557 get_arg_list_element(a, s, p, k, dp, shape)
2558 )
James Ward8b390432022-08-12 20:48:56 +01002559 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002560
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002561 # Now add data generator types
2562 arg_list = TosaArgGen._add_data_generators(
2563 testGen,
2564 opName,
2565 dtype,
2566 arg_list,
2567 error_name,
2568 )
2569
2570 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002571 return arg_list
2572
    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype, error_name=None):
        """Build the CAST argument list.

        Enumerates the legal output dtypes for the given input dtype (or a
        deliberately wrong set for ERROR_IF runs), then attaches the data
        generator metadata.

        Args:
            testGen: test generator providing helpers and CLI args
            opName: operator name the arguments are generated for
            shapeList: input shapes (unused here)
            inDtype: input tensor DType the output dtypes are enumerated for
            error_name: optional ErrorIf name to create negative tests for

        Returns:
            List of (arg_str, args_dict) tuples; args_dict holds "out_type".
        """
        arg_list = []

        # Enumerate the output types here
        if error_name == ErrorIf.WrongOutputType:
            dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
        elif inDtype == DType.INT8:
            dtypeList = [
                DType.BOOL,
                DType.INT16,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT16:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT32:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT16,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FP16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
        elif inDtype == DType.BF16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
        elif inDtype == DType.FP32:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output type for incorrect input type
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        # One argument entry per candidate output dtype
        for dtype in dtypeList:
            arg_list.append(
                ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
            )

        # Now add data generator types
        # NOTE(review): `dtype` below is the loop variable leaked from the
        # enumeration above, i.e. the LAST output dtype, not `inDtype` -
        # presumably intentional, but verify which dtype the data generator
        # selection should key on (would raise NameError if dtypeList were
        # ever empty)
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        return arg_list
2636
2637 @staticmethod
2638 def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
2639 arg_list = []
2640
2641 # Enumerate the output types here
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002642 for outDtype in [
2643 DType.UINT8,
2644 DType.INT8,
2645 DType.INT16,
2646 DType.INT32,
2647 DType.UINT16,
2648 ]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002649 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002650 outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002651 and error_name == ErrorIf.OutputZeroPointNotZero
2652 ):
2653 continue
2654 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002655 outDtype != DType.UINT16
2656 and error_name == ErrorIf.U16OutputZeroPointNotValid
2657 ) or (
2658 inDtype != DType.UINT16
2659 and error_name == ErrorIf.U16InputZeroPointNotValid
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002660 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002661 # ErrorIfs only valid with UINT16
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002662 continue
2663 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002664 inDtype == DType.UINT8
2665 and outDtype not in [DType.INT8, DType.INT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002666 and error_name != ErrorIf.WrongOutputType
2667 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002668 # The only output dtypes for UINT8 are INT8/INT16, skip all others
2669 continue
2670 if (
2671 inDtype not in [DType.INT8, DType.INT16]
2672 and outDtype == DType.UINT8
2673 and error_name != ErrorIf.WrongOutputType
2674 ):
2675 # The only input dtypes for UINT8 are INT8/INT16, skip all others
2676 continue
2677 if (
2678 inDtype == DType.UINT16
2679 and outDtype != DType.INT16
2680 and error_name != ErrorIf.WrongOutputType
2681 ):
2682 # The only output dtype for UINT16 is INT16, skip all others
2683 continue
2684 if (
2685 inDtype != DType.INT16
2686 and outDtype == DType.UINT16
2687 and error_name != ErrorIf.WrongOutputType
2688 ):
2689 # The only input dtype for UINT16 is INT16, skip all others
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002690 continue
2691 if (
2692 error_name == ErrorIf.WrongOutputType
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002693 and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002694 ):
2695 continue
2696
2697 for scale32 in [False, True]:
2698 if error_name == ErrorIf.ScaleTrue and not scale32:
2699 continue
2700 elif error_name == ErrorIf.ScaleNotTrue and scale32:
2701 continue
2702 for double_round in [False, True]:
2703 if error_name == ErrorIf.ScaleNotTrue and not double_round:
2704 continue
2705 for per_channel in [False, True]:
2706
2707 if (
2708 inDtype == DType.INT48
2709 and scale32
2710 and error_name != ErrorIf.ScaleTrue
2711 ):
2712 # Illegal condition. Must be scale32=False
2713 continue
2714 if (
2715 double_round
2716 and not scale32
2717 and error_name != ErrorIf.ScaleNotTrue
2718 ):
2719 # Illegal condition. ERROR_IF(!scale32 && double_round)
2720 continue
2721
2722 arg_list.append(
2723 (
2724 "out{}_sc{}_dr{}_pc{}".format(
Jeremy Johnson3b0544c2022-10-18 16:32:19 +01002725 testGen.typeStr(outDtype),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002726 int(scale32),
2727 int(double_round),
2728 int(per_channel),
2729 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002730 [outDtype, scale32, double_round, per_channel],
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002731 )
2732 )
2733
2734 return arg_list
2735
2736 @staticmethod
2737 def agMul(testGen, opName, shapeList, dtype, error_name=None):
2738 arg_list = []
2739
2740 if dtype is DType.INT32:
2741 for p in range(testGen.args.num_rand_permutations):
2742
2743 shift = testGen.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002744 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002745 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002746 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002747
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002748 arg_list = TosaArgGen._add_data_generators(
2749 testGen,
2750 opName,
2751 dtype,
2752 arg_list,
2753 error_name,
2754 )
2755 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002756 return arg_list
2757
2758 @staticmethod
2759 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
2760 arg_list = []
2761
2762 arg_list.append(("roundTrue", [True]))
2763 arg_list.append(("roundFalse", [False]))
2764
2765 return arg_list
2766
Luke Hutton57287132023-02-06 14:54:18 +00002767 @staticmethod
2768 def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
2769 arg_list = []
2770
2771 arg_list.append(("inverseTrue", [True]))
2772 arg_list.append(("inverseFalse", [False]))
2773
2774 return arg_list
2775
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002776 # Helper function for reshape. Gets some factors of a larger number.
2777 @staticmethod
2778 def getFactors(val, start=1):
2779 factors = []
2780
2781 for i in range(start, int(np.sqrt(val)) + 1):
2782 if (val % i) == 0:
2783 factors.append(i)
2784
2785 return factors
2786
2787 @staticmethod
2788 def agReshape(testGen, opName, shapeList, dtype, error_name=None):
2789 arg_list = []
2790
2791 origShape = shapeList[0]
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002792 totalElements = gtu.product(origShape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002793 factors = TosaArgGen.getFactors(totalElements)
2794
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002795 # Find new shapes up to the number of permutations asked for
2796 # This code is NOT fast. Fortunately, the numbers are fairly small.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002797 for p in range(testGen.args.num_rand_permutations):
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002798 # Rank from 1 to TOSA_TENSOR_MAX_RANK
2799 newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002800 if len(factors) < newRank:
2801 continue
2802
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002803 # escape_counter limits the generation of new shapes to a reasonable time
2804 for escape_counter in range(100):
2805
2806 # Generate the new shape of the chosen new rank
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002807 newShape = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002808 remainingElements = totalElements
2809 shuffledFactors = testGen.rng.permutation(factors)
2810 for i in range(1, newRank):
2811 # pick rank-1 factors
2812 newShape.append(shuffledFactors[0])
2813 remainingElements = remainingElements // shuffledFactors[0]
2814 shuffledFactors = testGen.rng.permutation(
2815 TosaArgGen.getFactors(remainingElements)
2816 )
2817 newShape.append(remainingElements)
2818
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002819 # Check for duplicates
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002820 duplicate = False
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00002821 for name, args_dict in arg_list:
2822 if args_dict["new_shape"] == newShape:
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002823 duplicate = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002824 break
2825
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002826 if not duplicate:
2827 outShape = "x".join([str(x) for x in newShape])
2828 arg_list.append(
2829 (
2830 "perm{}_rank{}_out{}".format(p, newRank, outShape),
2831 {"new_shape": newShape},
2832 )
2833 )
2834 # Found an output shape for this permutation
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002835 break
2836
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00002837 # Now add data generator types
2838 arg_list = TosaArgGen._add_data_generators(
2839 testGen,
2840 opName,
2841 dtype,
2842 arg_list,
2843 error_name,
2844 )
2845
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002846 return arg_list
2847
2848 @staticmethod
2849 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
2850 arg_list = []
2851
2852 ifm_shape = shapeList[0]
2853
2854 if error_name == ErrorIf.IndexOutsideBounds:
2855 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
2856 incorrect_small_index = range(-len(ifm_shape), 0)
2857 permutations = [p for p in itertools.permutations(incorrect_large_index)]
2858 permutations.extend(
2859 [p for p in itertools.permutations(incorrect_small_index)]
2860 )
2861 elif error_name == ErrorIf.IndexUsedTwice:
2862 # Create list with a duplicated index
2863 perm_range = list(range(len(ifm_shape)))
2864 index_choice = testGen.rng.choice(range(len(perm_range)))
2865 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
2866 permutations = [p for p in itertools.permutations(perm_range)]
2867
2868 else:
2869 # Get all permutations
2870 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
2871
2872 # Limit to possible permutations from shape dimension or argument setting
2873 limit = min(len(permutations), testGen.args.num_rand_permutations)
2874
2875 # Get random permutation generator that uses all permutations
2876 random_permutations = testGen.rng.permutation(permutations)
2877
2878 # Create list of required amount of permutations
2879 arg_list = [
2880 ("perm{}".format(p), [random_permutations[p].tolist()])
2881 for p in range(limit)
2882 ]
2883 return arg_list
2884
2885 @staticmethod
2886 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
2887 arg_list = []
2888
2889 ifm_shape = shapeList[0]
2890 rank = len(ifm_shape)
2891
2892 for p in range(testGen.args.num_rand_permutations):
2893 start = []
2894 size = []
2895
2896 valid = True
2897
2898 for i in range(rank):
2899 if ifm_shape[i] > 1:
2900 start.append(testGen.randInt(0, ifm_shape[i]))
2901 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
2902
2903 # Invalid slice size?
2904 if size[i] == 0:
2905 valid = False
2906 else:
2907 start.append(0)
2908 size.append(1)
2909
2910 if valid:
2911 # If ERROR_IF test required then incorrect start, size will be returned
2912 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
2913 testGen, error_name, ifm_shape, start, size
2914 )
evacha017f7d4252024-01-24 12:08:09 +00002915 arg_list.append(("perm{}".format(p), {"start": start, "size": size}))
2916 # Now add data generator types
2917 arg_list = TosaArgGen._add_data_generators(
2918 testGen,
2919 opName,
2920 dtype,
2921 arg_list,
2922 error_name,
2923 )
2924 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002925 return arg_list
2926
2927 @staticmethod
2928 def agTile(testGen, opName, shapeList, dtype, error_name=None):
2929 arg_list = []
2930
2931 ifm_shape = shapeList[0]
2932 rank = len(ifm_shape)
2933
2934 for p in range(testGen.args.num_rand_permutations):
2935
2936 # Pick a few random, but small multiple values
2937 # because otherwise this has a tendency to generate
2938 # enormous tensors
2939 multiples = []
2940 for i in range(rank):
2941 if ifm_shape[i] > 1000:
2942 # Multiple of 1 if ifm_shape dimension is large to reduce
2943 # tensor size
2944 multiples.append(1)
2945 elif max(ifm_shape) > 1000:
2946 multiples.append(2)
2947 else:
2948 multiples.append(testGen.randInt(1, 4))
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00002949 arg_list.append(("perm{}".format(p), {"multiples": multiples}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002950
Jeremy Johnson9f5febe2024-01-15 15:12:17 +00002951 # Now add data generator types
2952 arg_list = TosaArgGen._add_data_generators(
2953 testGen,
2954 opName,
2955 dtype,
2956 arg_list,
2957 error_name,
2958 )
2959 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002960 return arg_list
2961
2962 @staticmethod
2963 def agResize(testGen, opName, shapeList, dtype, error_name=None):
2964 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002965 ifm_shape = shapeList[0]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002966
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01002967 def get_aspect_ratio_resize_params():
2968 common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
2969 aspect_ratio = testGen.rng.choice(common_aspect_ratios)
2970 invert = testGen.rng.choice((False, True))
2971 letterbox = testGen.rng.choice((False, True))
2972
2973 scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
2974 scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
2975 scale_y_d = scale_x_d = 1
2976 offset_x = offset_y = 0
2977
2978 if letterbox:
2979 max_border = scale_y_n
2980 border_y = testGen.randInt(low=0, high=max_border)
2981 border_x = 0
2982 else:
2983 # Pillarboxing
2984 border_y = 0
2985 max_border = scale_x_n
2986 border_x = testGen.randInt(low=0, high=max_border)
2987
2988 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
2989 offset = (offset_y, offset_x)
2990 border = (border_y, border_x)
2991
2992 return scale, offset, border
2993
2994 def get_upscale_downscale_params():
2995 valid_params = False
2996 while not valid_params:
2997 upscale = testGen.rng.choice((False, True))
2998
2999 # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
3000 origin_sampling = testGen.rng.choice((False, True))
3001
3002 if upscale:
3003 shift = testGen.randInt(low=1, high=4)
3004 scale_x_d = scale_y_d = 1
3005 scale_x_n = scale_y_n = (
3006 1 << shift if origin_sampling else 2 << shift
3007 )
3008 border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
3009 offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
3010 else:
3011 scale_x_n = 1
3012 scale_y_n = 1
3013
3014 # Return list of valid scale_*_d values (max value 4) given input dim shape
3015 def get_valid_denom(ifm_dim):
3016 return [x for x in range(1, 5) if ifm_dim % x == 1]
3017
3018 # Generate list of valid downscale values and choose one randomly
3019 valid_scale_y_ds = get_valid_denom(ifm_shape[1])
3020 valid_scale_x_ds = get_valid_denom(ifm_shape[2])
3021
3022 if not valid_scale_y_ds and not valid_scale_x_ds:
3023 # Bad parameters, skip
3024 continue
3025
3026 if not valid_scale_y_ds:
3027 scale_y_d = 1
3028 else:
3029 scale_y_d = testGen.rng.choice(valid_scale_y_ds)
3030
3031 if not valid_scale_x_ds:
3032 scale_x_d = 1
3033 else:
3034 scale_x_d = testGen.rng.choice(valid_scale_x_ds)
3035
3036 border_x = border_y = 0
3037 offset_y = testGen.randInt(0, 16 * scale_y_n)
3038 offset_x = testGen.randInt(0, 16 * scale_x_n)
3039 valid_params = True
3040
3041 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3042 offset = (offset_y, offset_x)
3043 border = (border_y, border_x)
3044 return scale, offset, border
3045
3046 def get_rand_params():
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003047 def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
3048 scale = scale_n / scale_d
3049 if scale > max_scale:
3050 factor = scale / max_scale
3051 new_scale_d = math.ceil(scale_d * factor)
3052 assert scale_n / new_scale_d <= max_scale
3053 scale_d = new_scale_d
3054 return scale_d
3055
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003056 # Scale
3057 scale_y_n = testGen.randInt(low=1, high=(1 << 11))
3058 scale_x_n = testGen.randInt(low=1, high=(1 << 11))
3059
3060 scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
3061 scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))
3062
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003063 scale_y_d = fix_scale_to_max_scale(
3064 scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3065 )
3066 scale_x_d = fix_scale_to_max_scale(
3067 scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
3068 )
3069
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003070 # Offsets and border within the scale
3071 offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
3072 offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
3073 border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
3074 border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)
3075
3076 scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
3077 offset = (offset_y, offset_x)
3078 border = (border_y, border_x)
3079 return scale, offset, border
3080
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003081 def get_level_8k_params():
3082 # Create 64x scale - 64/1 to 2048/32
3083 scale_d = testGen.randInt(
3084 low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
3085 )
3086 scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
3087 # Create half to fifth scaling
3088 scale_d_alt = testGen.randInt(low=2, high=6)
3089 scale_n_alt = 1
3090 switch = testGen.rng.choice((False, True))
3091 if switch:
3092 scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
3093 else:
3094 scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)
3095
3096 offset_y = testGen.rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
3097 offset_x = testGen.rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
3098 offset = (offset_y, offset_x)
3099 border_y = testGen.rng.choice((-16 * scale[0], 0, scale[0] - 1))
3100 border_x = testGen.rng.choice((-16 * scale[2], 0, scale[2] - 1))
3101 border = (border_y, border_x)
3102 return scale, offset, border
3103
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003104 for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003105 # Exclude illegal {mode, type} configurations. Pick legal output types
3106 if mode == ResizeMode.NEAREST and dtype == DType.INT8:
3107 outputDTypeList = [DType.INT8]
3108 elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
3109 outputDTypeList = [DType.INT16]
3110 elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
3111 outputDTypeList = [DType.INT32]
3112 elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
3113 outputDTypeList = [DType.INT48]
James Ward8b390432022-08-12 20:48:56 +01003114 elif dtype == DType.FP16:
3115 outputDTypeList = [DType.FP16]
James Ward24dbc422022-10-19 12:20:31 +01003116 elif dtype == DType.BF16:
3117 outputDTypeList = [DType.BF16]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01003118 elif dtype == DType.FP32:
3119 outputDTypeList = [DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003120 elif error_name == ErrorIf.WrongInputType:
3121 # If an incorrect input type is used then we set a 'correct'
3122 # output type to avoid other errors
3123 outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
3124 else:
3125 continue
3126
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003127 arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"
3128
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003129 for outputDType in outputDTypeList:
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003130 perm = 0
3131 while perm < testGen.args.num_rand_permutations:
3132 # Random choice of type of params we are testing
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003133 if not testGen.args.level8k:
3134 _rnd_param_fn = testGen.rng.choice(
3135 (
3136 get_rand_params,
3137 get_upscale_downscale_params,
3138 get_aspect_ratio_resize_params,
3139 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003140 )
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003141 scale, offset, border = _rnd_param_fn()
3142 else:
3143 scale, offset, border = get_level_8k_params()
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003144
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003145 # Expand params for bounds-checking
3146 (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
3147 (offset_y, offset_x) = offset
3148 (border_y, border_x) = border
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003149
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003150 # Make sure output dimensions OH and OW are integers
3151 partial_output_y = (
3152 (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
3153 )
3154 partial_output_x = (
3155 (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
3156 )
3157 if error_name == ErrorIf.ResizeOutputShapeNonInteger:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003158 # Look for non-integer test
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003159 if (
3160 partial_output_y % scale_y_d == 0
3161 and partial_output_x % scale_x_d == 0
3162 ):
3163 # Skip this test as it doesn't produce NonInteger output
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003164 if perm > 0:
3165 perm += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003166 continue
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003167 else:
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003168 # Alter the scaling factors to make the output integer
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003169 while partial_output_y % scale_y_d != 0:
3170 scale_y_d -= 1
3171 while partial_output_x % scale_x_d != 0:
3172 scale_x_d -= 1
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003173 # Make sure we are still within max scaling
3174 if (
3175 scale_y_n / scale_y_d
3176 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
3177 scale_x_n / scale_x_d
3178 ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
3179 # Skip the test as it is using too large a scaling factor
3180 if perm > 0:
3181 perm += 1
3182 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003183
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003184 output_y = partial_output_y // scale_y_d + 1
3185 output_x = partial_output_x // scale_x_d + 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003186
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003187 if (
3188 output_y >= testGen.args.max_resize_output_dim
3189 or output_x >= testGen.args.max_resize_output_dim
3190 ) and error_name is None:
3191 # Skip positive test if output dim will be too high
3192 # Avoid high test latency and OOM issues
Jeremy Johnsonb2099702023-04-12 15:59:01 +01003193 if not testGen.args.level8k or perm > 0:
3194 perm += 1
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003195 continue
3196
3197 if (
3198 output_y <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003199 or output_y >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003200 or output_x <= 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01003201 or output_x >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003202 ):
3203 # Output dimensions out of scope
3204 if error_name is not None and perm > 0:
3205 # As long as we have one ERROR_IF test, don't worry
3206 # about creating all the other permutations
3207 perm += 1
3208 continue
3209
3210 if error_name == ErrorIf.ResizeOutputShapeMismatch and (
3211 (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003212 output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003213 and output_y - scale_y_d < 1
3214 )
3215 or (
Jeremy Johnson1271c442023-09-05 11:39:26 +01003216 output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003217 and output_x - scale_x_d < 1
3218 )
3219 ):
3220 # Can't create a negative test with these params as it
3221 # will create invalid output size
3222 if perm > 0:
3223 perm += 1
3224 continue
3225
3226 scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
3227 offset = [offset_y, offset_x]
3228 border = [border_y, border_x]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003229
3230 # Common for all data types
3231 if error_name is not None:
3232 (
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003233 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003234 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003235 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003236 outputDTypeNew,
3237 ) = TosaErrorIfArgGen.eiResizeErrorIf(
3238 testGen,
3239 error_name,
3240 mode,
3241 dtype,
3242 shapeList,
3243 outputDType,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003244 scale,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003245 offset,
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003246 border,
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003247 )
3248 else:
3249 outputDTypeNew = outputDType
3250
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003251 arg_to_append = (
3252 arg_str.format(
3253 "N" if mode == ResizeMode.NEAREST else "B",
3254 testGen.typeStr(outputDTypeNew),
3255 scale[0],
3256 scale[1],
3257 scale[2],
3258 scale[3],
3259 offset[0],
3260 offset[1],
3261 border[0],
3262 border[1],
3263 ),
3264 [
3265 mode,
3266 scale,
3267 offset,
3268 border,
3269 dtype,
3270 outputDTypeNew,
3271 ],
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003272 )
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003273 if arg_to_append in arg_list:
3274 # Skip already generated test params
3275 continue
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003276
Jeremy Johnsona0e03f32022-06-13 17:48:09 +01003277 # Valid permutation
3278 perm += 1
3279 arg_list.append(arg_to_append)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003280 return arg_list
3281
3282 @staticmethod
3283 def agTable(testGen, opName, shapeList, dtype, error_name=None):
3284 arg_list = []
3285
3286 if dtype == DType.INT8:
3287 table = np.int32(
3288 testGen.rng.integers(low=-128, high=128, size=[256])
3289 ).tolist()
3290 else: # INT16
3291 table = np.int32(
3292 testGen.rng.integers(low=-32768, high=32768, size=[513])
3293 ).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003294 # Make sure all slopes are within REQUIRE min/max 16-bit int
3295 for idx in range(len(table) - 1):
3296 slope = table[idx + 1] - table[idx]
3297 # Alter the next table entry to force the slope to be ok
3298 if slope > 32767:
3299 table[idx + 1] -= slope - 32767
3300 if slope < -32768:
3301 table[idx + 1] -= slope + 32768
3302 slope = table[idx + 1] - table[idx]
3303 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003304 arg_list.append(
3305 (
3306 "",
3307 [table],
3308 )
3309 )
3310 return arg_list
3311
3312 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
3313 # CondIf generates the condition values here.
3314 # Convert to tensors in the build function, along with the
3315 # then and else blocks
3316 arg_list = []
3317
3318 for c in [False, True]:
3319 arg_list.append(("cond{}".format(int(c)), [c]))
3320
3321 return arg_list
3322
3323 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
3324 # While loop: 0 iterations, 1, more than 1
3325 arg_list = []
3326
3327 for iter in [0, 1, 4]:
3328 arg_list.append(("iter{}".format(iter), [iter]))
3329
3330 return arg_list