blob: 86414994fd899b7a7d8abfebdad296f07c70473f [file] [log] [blame]
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
4import math
James Ward8b390432022-08-12 20:48:56 +01005import warnings
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01006
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
16# DTypeNames, DType, Op and ResizeMode are convenience variables to the
17# flatc-generated types that should be enums, but aren't
18
19
class TosaQuantGen:
    """Random generator helpers for operator quantization info.

    Operators request this by specifying 'qgen': in their operator definition.
    """

    def __init__(self):
        # Stateless - every generator below is a static method.
        pass

    @staticmethod
    def getZeroPoint(testGen, dtype, error_name=None):
        """Return a zero point for the given dtype.

        INT8/UINT8 use the user-supplied zero point (clamped to range) when one
        was given, otherwise a random in-range value; the ERROR_IF zero-point
        cases get a guaranteed non-zero value; every other dtype gets 0.
        """
        if dtype == DType.INT8:
            user_zp = testGen.args.zeropoint
            if user_zp is not None:
                return min(127, max(-128, user_zp))
            return testGen.randInt(-128, 128)

        if dtype == DType.UINT8:
            user_zp = testGen.args.zeropoint
            if user_zp is not None:
                return min(255, max(0, user_zp))
            return testGen.randInt(0, 256)

        # NOTE(review): the dtype branches above win for INT8/UINT8, so an
        # ERROR_IF request on those types may still randomly produce a zero
        # point of 0 - confirm this is intended.
        if error_name in (
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ):
            zero_point = testGen.randInt(-128, 128)
            return 1 if zero_point == 0 else zero_point

        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Return [input_zp, output_zp] for a unary operator."""
        input_err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        output_err = (
            error_name if error_name == ErrorIf.OutputZeroPointNotZero else None
        )
        # Draw order preserved: input zero point first, then output
        return [
            TosaQuantGen.getZeroPoint(testGen, dtype, input_err),
            TosaQuantGen.getZeroPoint(testGen, dtype, output_err),
        ]

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Return [input_zp, weight_zp] for a convolution-style operator."""
        if isinstance(dtype_or_dtypeList, list):
            # Already a list of [input, weights, accumulator] dtypes
            dtypes = dtype_or_dtypeList
        else:
            # A single dtype: input, weights and accumulator all match
            dtypes = [dtype_or_dtypeList] * 3

        input_err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        weight_err = (
            error_name if error_name == ErrorIf.WeightZeroPointNotZero else None
        )
        # Draw order preserved: input zero point first, then weight
        return [
            TosaQuantGen.getZeroPoint(testGen, dtypes[0], input_err),
            TosaQuantGen.getZeroPoint(testGen, dtypes[1], weight_err),
        ]

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Return [a_zp, b_zp] for MATMUL - both operands share the dtype."""
        err = error_name if error_name == ErrorIf.InputZeroPointNotZero else None
        return [
            TosaQuantGen.getZeroPoint(testGen, dtype, err),
            TosaQuantGen.getZeroPoint(testGen, dtype, err),
        ]

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into a fixed-point multiplier and shift.

        Derived from computeMultiplierAndShiftTosaScale32; scale32 selects
        31-bit precision, otherwise 15-bit is used.
        """
        scaleBits = 31 if scale32 else 15

        mantissa, exponent = math.frexp(scaleFp)
        if scaleFp < 0.0:
            # frexp returns a negative mantissa for negative inputs; flip it
            # so the fixed point multiplier is positive
            mantissa = -mantissa

        multiplier = round(mantissa * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            # Rounding overflowed to 2^scaleBits - renormalise
            multiplier //= 2
            exponent += 1

        shift = scaleBits - exponent

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier //= 4
            shift += 2
        elif shift == 1:
            multiplier //= 2
            shift += 1
        elif shift == 63:
            multiplier *= 2
            shift -= 1

        assert multiplier <= (1 << scaleBits)
        assert 2 <= shift <= 62

        return multiplier, shift
152
153
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        # Stateless - all generators are static methods.
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Create identical random shapes for every operand; may rank-mismatch
        later operands for ERROR_IF.RankMismatch tests."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Create identical rank 4 (NHWC) shapes with a constricted batch size."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgGather(testGen, opName, rank, error_name=None):
        """Create [values (N,K,C), indices (N,W)] shapes for GATHER."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_shape = testGen.makeShape(rank)
        values_shape = testGen.constrictBatchSize(values_shape)

        N = values_shape[0]
        W = testGen.makeDimension()
        indices_shape = [N, W]

        shape_list = [values_shape, indices_shape]
        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Create [values_in (N,K,C), indices (N,W), input (N,W,C)] shapes
        for SCATTER, with W limited so no output index repeats."""
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = testGen.randInt(W_min, W_max) if W_min < W_max else W_min

        input_shape = [N, W, C]

        shape_list = []
        shape_list.append(values_in_shape)
        shape_list.append([N, W])  # indices
        shape_list.append(input_shape)

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Create shapes where one randomly chosen input broadcasts (dim set
        to 1) along a randomly chosen axis; supports rank/broadcast ERROR_IFs."""
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        fuzz_idx = testGen.randInt(0, rank)

        for i in range(pl + const):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (OHWI), bias (OC)] shapes for CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Create [ifm (NDHWC), filter (ODHWI), bias (OC)] shapes for CONV3D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension()

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (OHWI), bias (OC)] shapes for
        TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Create [ifm (NHWC), filter (HWCM), bias (C*M)] shapes for
        DEPTHWISE_CONV2D, keeping the channel multiplier M small."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        """Create two identical (N, H, W) shapes with power-of-two H/W for
        FFT2D; can deliberately break the power-of-two or shape-match rules."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        """Create a single (N, H, W) shape with power-of-two H/W for RFFT2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Create [input (N,IC), filter (OC,IC), bias (OC)] shapes for
        FULLY_CONNECTED."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Create [a (N,H,C), b (N,C,OC)] shapes for MATMUL."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Create a variable number of identical shapes to concatenate; can
        rank-mismatch later inputs for ERROR_IF.ConcatInputRankMismatch."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Split the first concat shape along axis so multiple const inputs can
        be used without creating too many large tensors."""
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
624
625
626class TosaTensorValuesGen:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100627 """Tensor Value generators create the random data for each tensor in each test."""
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100628
    def __init__(self):
        # Stateless helper - all visible entry points are static methods.
        pass
631
    class TVGInfo:
        """Enhanced tensor values information including data gen dict."""

        def __init__(self, tensorList, dataGenDict):
            # Serializer tensors created for the test
            self.tensorList = tensorList
            # Data generator meta-data dict (None when data was created eagerly)
            self.dataGenDict = dataGenDict
638
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100639 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000640 def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100641 pCount, cCount = op["operands"]
642
643 tens = []
644 tens.extend(
645 testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
646 )
647 tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
648
649 return tens
650
    # Default high value for random numbers - the largest finite value of each
    # format, i.e. (2 - 2^-mantissa_bits) * 2^max_exponent
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
    }

    # Default lowest normal values for random numbers - the smallest positive
    # normal of each format, i.e. 2^(1 - exponent_bias)
    TVG_FLOAT_LOW_VALUE = {
        DType.FP32: np.exp2(-126),
        DType.FP16: np.exp2(-14),
        DType.BF16: np.exp2(-126),
    }
664
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100665 @staticmethod
Jeremy Johnson30476252023-11-20 16:15:30 +0000666 def _get_data_range(testGen, dtype, highValueLookup, lowValueLookup=None):
667 # Return a tuple of (low,high) data range values for the given data
668 # type using a combination of per operator table limits, data limits
669 # and user supplied ranges for FP numbers
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000670 if dtype in highValueLookup:
Jeremy Johnson30476252023-11-20 16:15:30 +0000671 type_range = testGen.getDTypeRange(dtype, high_inclusive=True)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000672 high_val = highValueLookup[dtype]
Jeremy Johnson30476252023-11-20 16:15:30 +0000673 if lowValueLookup is not None and dtype in lowValueLookup:
674 low_val = lowValueLookup[dtype]
675 else:
676 low_val = -high_val
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000677 # Set the values to something that won't produce infinity whilst
Jeremy Johnson30476252023-11-20 16:15:30 +0000678 # respecting the default ranges if more/less than the low/high
679 # values
680 data_range = (
681 max(low_val, type_range[0]),
682 min(high_val, type_range[1]),
683 )
684 if data_range[0] > data_range[1]:
685 # Invalid data range from low to high created due to user
686 # constraints revert to using internal ranges as they are
687 # known to work
688 msg = f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
689 warnings.warn(msg)
690 data_range = (low_val, high_val)
691 return data_range
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000692 return None
693
    @staticmethod
    def tvgLazyGenDefault(
        testGen, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Create serializer tensors plus data generator meta-data for a test.

        Falls back to eagerly generated random data for ERROR_IF tests,
        non-compliance dtypes or ops without "data_gen" support; otherwise
        builds a meta-data dictionary (and only creates the data now when
        lazy data generation is off, or for the dot-product bias tensor).
        """
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
        tens_ser_list = []

        if (
            error_name is not None
            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
            or "data_gen" not in testGen.TOSA_OP_LIST[opName]
        ):
            # Fall back to internal data gen when dealing with unsupported types or ops
            data_range = argsDict["data_range"] if "data_range" in argsDict else None
            for idx, info in enumerate(zip(shapeList, dtypeList)):
                roundMode = False
                shape, dtype = info
                if "data_range_list" in argsDict:
                    # Per-tensor range overrides the single data_range
                    data_range = argsDict["data_range_list"][idx]["range"]
                    roundMode = (
                        "round" in argsDict["data_range_list"][idx]
                        and argsDict["data_range_list"][idx]["round"] is True
                    )
                if data_range is not None and dtype not in (
                    DType.FP16,
                    DType.FP32,
                    DType.BF16,
                ):
                    # Change from inclusive to exclusive range
                    data_range = (data_range[0], data_range[1] + 1)
                # Ignore lazy data gen option and create data array using any range limits
                arr = testGen.getRandTensor(shape, dtype, data_range)
                if roundMode:
                    arr = np.round(arr)
                if idx < pCount:
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
        dg_type = argsDict["dg_type"]
        tens_data = {
            "version": "0.1",
            "tensors": {},
        }
        dg_tens_meta = tens_data["tensors"]
        for idx, shape in enumerate(shapeList):

            tens_meta = {}
            tens_meta["generator"] = gtu.DataGenType(dg_type).name
            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()

            if idx < pCount:
                tens_meta["input_type"] = "VARIABLE"
            else:
                tens_meta["input_type"] = "CONSTANT"

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                # TODO - generate seed for this generator based on test
                info["rng_seed"] = 42

                data_range = None
                if "data_range_list" in argsDict:
                    data_range = argsDict["data_range_list"][idx]["range"]
                    if "round" in argsDict["data_range_list"][idx]:
                        info["round"] = argsDict["data_range_list"][idx]["round"]
                elif "data_range" in argsDict:
                    data_range = argsDict["data_range"]

                if data_range is None:
                    # No op-specific limits - use the full dtype range
                    data_range = testGen.getDTypeRange(
                        dtypeList[idx], high_inclusive=True
                    )
                info["range"] = [str(v) for v in data_range]
                tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = int(argsDict["ks"])
                if "acc_type" in argsDict:
                    # Convert type number into JSON name
                    info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
                        "json"
                    ]
                if "kernel" in argsDict:
                    info["kernel"] = [int(k) for k in argsDict["kernel"]]
                if "axis" in argsDict:
                    info["axis"] = int(argsDict["axis"])
                tens_meta["dot_product_info"] = info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"

            # Using the finished generate config meta data - generate the data if
            # needed and assign a tensor name from the serializer

            # Need to generate data when not lazy or for the bias tensor as we need
            # to work out if the bias data is non-zero for compliance
            if not testGen.args.lazy_data_gen or (
                idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
            ):
                # Give this tensor a temporary name until we get one from the serializer
                temp_name = f"placeholder_{idx}"
                dg_tens_meta[temp_name] = tens_meta
                # Create data now using the temporary name to access meta details
                data = testGen.dgl.get_tensor_data(temp_name, tens_data)
                # Remove the item as we will give it the correct name later
                del dg_tens_meta[temp_name]

                if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
                    # The KS value used by compliance verification is altered when the
                    # bias data is non-zero
                    if max(abs(data)) > 0.0:
                        argsDict["ksb"] = argsDict["ks"] + 1

            # In lazy mode no data is serialized (even the bias data created
            # above, which was only needed for the KS adjustment)
            if testGen.args.lazy_data_gen:
                data = None

            if tens_meta["input_type"] == "VARIABLE":
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
            else:
                tens = testGen.ser.addConst(shape, dtypeList[idx], data)

            tens_ser_list.append(tens)
            # Add the meta data to the list using the serializer tensor name
            dg_tens_meta[tens.name] = tens_meta

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
829
830 @staticmethod
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000831 def tvgNegate(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100832 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000833 # Integer test
834 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100835 pCount, cCount = op["operands"]
836 assert (
837 pCount == 1 and cCount == 0
838 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100839 # Must create tensors with values within accumulator (int32) negatable
840 # range
841 max_val = (1 << 31) - 1
842 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100843 arr = np.int32(
844 testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
845 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000846 tens_ser_list = []
847 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100848 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
849 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000850 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100851 else:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000852 # ERROR_IF or floating point test
853 return TosaTensorValuesGen.tvgLazyGenDefault(
854 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100855 )
856
Jeremy Johnson30476252023-11-20 16:15:30 +0000857 # Set the ADD/SUB data range to half the largest value to avoid infinities
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000858 TVG_FLOAT_HIGH_VALUE_ADDSUB = {
859 DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
860 DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
861 DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
862 }
863
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100864 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000865 def tvgAddSub(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100866 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000867 # Make sure the integer operation does not cause value saturation - where
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100868 # the number wraps due to limited number of bits to store the answer
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000869 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100870 pCount, cCount = op["operands"]
871 assert (
872 pCount == 2 and cCount == 0
873 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000874 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100875 add = op["op"] == Op.ADD
876 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
877 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
878 if add:
879 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
880 else:
881 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
882
883 # Work out the saturation limits
884 max_i32 = (1 << 31) - 1
885 min_i32 = -(1 << 31)
886 max_arr = np.full(shapeList[1], max_i32)
887 min_arr = np.full(shapeList[1], min_i32)
888
889 # Find how much values exceed the maximum/minimums
890 sat_max_arr = np.maximum(res_arr - max_arr, 0)
891 sat_min_arr = np.minimum(res_arr - min_arr, 0)
892
893 if not add:
894 # Swap saturation values and negate values as we need to perform opposite operations
895 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
896
897 # Create new array of unsaturated values by clipping values as needed
898 b_unsat_arr = b_arr
899 if (sat_max_arr != 0).any():
900 # Clip values that cause saturation
901 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
902 # Reduce axes in unsaturated tensor to match original tensor
903 for axis, dim in enumerate(b_arr.shape):
904 if dim != b_unsat_arr.shape[axis]:
905 assert (
906 dim == 1
907 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
908 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
909
910 if (sat_min_arr != 0).any():
911 # Clip values that cause saturation
912 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
913 # Reduce axes in unsaturated tensor to match original tensor
914 for axis, dim in enumerate(b_arr.shape):
915 if dim != b_unsat_arr.shape[axis]:
916 assert (
917 dim == 1
918 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
919 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
920
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000921 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100922 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
923 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000924 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100925 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
926 )
927
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000928 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100929 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000930 # ERROR_IF or floating point test
931 data_range = TosaTensorValuesGen._get_data_range(
932 testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
933 )
934 if data_range:
935 argsDict["data_range"] = data_range
936
937 return TosaTensorValuesGen.tvgLazyGenDefault(
938 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100939 )
940
941 @staticmethod
942 def tvgCondIfWhileLoop(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000943 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100944 ):
945 if dtypeList[0] in (
946 DType.INT32,
947 DType.INT16,
948 DType.INT8,
949 ):
950 # Limit input tensors with cond_if_binary or while_loop to stop
951 # saturation of add/sub ops with int32 and keep all logical shift
952 # values between 0 to 31 for int16 or int8
953 pCount, cCount = op["operands"]
954 pRemain = pCount
955 placeholders = []
956 for idx, shape in enumerate(shapeList[:]):
957 if dtypeList[0] == DType.INT32:
958 arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
959 else:
960 arr = np.int32(
961 testGen.rng.integers(low=0, high=32, size=shapeList[idx])
962 )
963 if pRemain > 0:
964 placeholders.append(
965 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
966 )
967 pRemain -= 1
968 else:
969 placeholders.append(
970 testGen.ser.addConst(shape, dtypeList[idx], arr)
971 )
972
973 return placeholders
974 else:
975 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000976 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100977 )
978
979 @staticmethod
980 def tvgArithmeticRightShift(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000981 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100982 ):
983 pCount, cCount = op["operands"]
984 # Force value of operand[1] to be within [0, num_bits]
985 assert (
986 pCount == 2 and cCount == 0
987 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
988
989 placeholders = []
990 for idx, shape in enumerate(shapeList[:]):
991 if idx == 1:
992 if dtypeList[idx] == DType.INT8:
993 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
994 elif dtypeList[idx] == DType.INT16:
995 arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
996 elif dtypeList[idx] == DType.INT32:
997 arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
998 elif error_name == ErrorIf.WrongInputType:
999 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
1000 else:
1001 raise Exception("OpArithmeticRightShift: invalid input dtype")
1002 else:
1003 arr = testGen.getRandTensor(shape, dtypeList[idx])
1004 placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
1005
1006 return placeholders
1007
1008 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001009 def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001010 # Set datatype of condition tensor to boolean
1011 dtypeList[0] = DType.BOOL
1012
1013 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001014 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001015 )
1016
1017 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001018 def tvgIntDiv(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001019 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001020 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001021 pCount, cCount = op["operands"]
1022 assert (
1023 pCount == 2 and cCount == 0
1024 ), "Op.INTDIV must have 2 placeholders, 0 consts"
1025
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001026 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001027
1028 # Two invalid cases for Op.INTDIV:
1029 # 1. divisor == 0
1030 # 2. dividend == -(1<<31) and divisor == -1
1031 while True:
1032 dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1033 divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
1034
1035 if (divisor_arr == 0).any():
1036 continue
1037
1038 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
1039 continue
1040
1041 break
1042
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001043 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001044 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1045 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001046 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001047 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1048 )
1049
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001050 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001051 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001052 return TosaTensorValuesGen.tvgLazyGenDefault(
1053 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001054 )
1055
Jeremy Johnson30476252023-11-20 16:15:30 +00001056 # Set the MUL data range to the square root of the largest value
1057 # to avoid infinities
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001058 TVG_FLOAT_HIGH_VALUE_MUL = {
1059 DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1060 DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1061 DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1062 }
1063
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001064 @staticmethod
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001065 def tvgMul(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1066 if error_name is not None or dtypeList[0] in (
1067 DType.FP16,
1068 DType.BF16,
1069 DType.FP32,
1070 ):
1071 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001072 data_range = TosaTensorValuesGen._get_data_range(
1073 testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
1074 )
1075 if data_range:
1076 argsDict["data_range"] = data_range
1077
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001078 return TosaTensorValuesGen.tvgLazyGenDefault(
1079 testGen, opName, dtypeList, shapeList, argsDict, error_name
1080 )
1081 else:
1082 # Integer test
1083 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001084 pCount, cCount = op["operands"]
1085 assert (
1086 pCount == 2 and cCount == 0
1087 ), "Op.MUL must have 2 placeholders, 0 consts"
1088
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001089 tens_ser_list = []
1090
1091 # Make sure multiply result in int32 range
1092 shift = argsDict["shift"]
1093 if dtypeList[0] == DType.INT8:
1094 num_bits = 8
1095 elif dtypeList[0] == DType.INT16:
1096 num_bits = 16
1097 elif dtypeList[0] == DType.INT32:
1098 num_bits = 32
1099 elif error_name == ErrorIf.WrongInputType:
1100 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001101 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001102 raise Exception("OpMul: invalid input dtype")
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001103
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001104 for idx, shape in enumerate(shapeList[:]):
1105 low = -(2 ** (num_bits - 1))
1106 high = (2 ** (num_bits - 1)) - 1
1107
1108 a_arr = np.int32(
1109 testGen.rng.integers(low=low, high=high, size=shapeList[0])
1110 )
1111 b_arr = np.int32(
1112 testGen.rng.integers(low=low, high=high, size=shapeList[1])
1113 )
1114
1115 i = 0
1116 while True:
1117
1118 a_arr_64 = a_arr.astype(np.int64)
1119 b_arr_64 = b_arr.astype(np.int64)
1120
1121 if shift > 0:
1122 rounding = 1 << (shift - 1)
1123 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001124 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001125 result_arr = a_arr_64 * b_arr_64
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001126
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001127 if (result_arr > -(2**31)).all() and (
1128 result_arr <= ((2**31) - 1)
1129 ).all():
1130 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001131
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001132 i = i + 1
1133 a_arr = a_arr // 2
1134 b_arr = b_arr // 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001135
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001136 tens_ser_list.append(
1137 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001138 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001139 tens_ser_list.append(
1140 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1141 )
1142
1143 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001144
1145 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001146 def tvgConcat(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001147 count = len(shapeList) - testGen.args.num_const_inputs_concat
1148 if count < 1:
1149 count = 1
1150 if testGen.args.num_const_inputs_concat == 0:
1151 count = len(shapeList)
1152
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001153 shapeList = TosaTensorGen.tgConcatConstInput(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001154 testGen, shapeList, argsDict["axis"], error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001155 )
1156
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001157 tens_ser_list = []
1158 tens_ser_list.extend(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001159 testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
1160 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001161 tens_ser_list.extend(
1162 testGen.buildConstTensors(shapeList[count:], dtypeList[count:])
1163 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001164
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001165 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001166
1167 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001168 def tvgLogicalShift(
1169 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1170 ):
1171 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001172 pCount, cCount = op["operands"]
1173 assert (
1174 pCount == 2 and cCount == 0
1175 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
1176 values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1177 shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001178 tens_ser_list = []
1179 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001180 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1181 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001182 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001183 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1184 )
1185
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001186 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001187
    @staticmethod
    def tvgEqual(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create input data for EQUAL/GREATER/GREATER_EQUAL style comparisons.

        For integer types unsupported by the compliance data generator,
        random data is patched so a number of positions are guaranteed to
        compare equal; otherwise the default lazy generator is used.
        """
        if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
            # Integer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"

            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])

            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        testGen.rng.integers(
                            0, max(shapeList[0][axis], shapeList[1][axis])
                        )
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                # Copy the value across so this position compares equal
                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
1234
    @staticmethod
    def tvgReduceSum(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create input data for REDUCE_SUM.

        INT32 values are limited so summing the largest axis cannot exceed
        int32; non dot-product float tests get a similarly limited data
        range; everything else defers to the default lazy generator.
        """
        dtype = dtypeList[0]
        if dtype == DType.INT32:
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
            # Limit values so that the sum cannot exceed the range of an int32 during
            # summation of any axis
            range_val = int((1 << 31) / max(shapeList[0]))
            values_arr = np.int32(
                testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
            )
            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or dot product floating point test
            if (
                error_name is None
                and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
            ):
                # Limit ranges for (non error & non compliance) tests by using
                # values that can be summed on any axis to not hit infinity
                highval_lookup = {
                    dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
                    / max(shapeList[0])
                }
                data_range = TosaTensorValuesGen._get_data_range(
                    testGen, dtype, highval_lookup
                )
                assert data_range is not None
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
1276
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001277 @staticmethod
1278 def tvgReduceProduct(
1279 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1280 ):
1281 dtype = dtypeList[0]
1282 if error_name is None:
1283 # Limit ranges for (non error) tests by using
1284 # values that can be multiplied on any axis to not hit infinity
1285 highval_lookup = {
1286 dtype: math.pow(
1287 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
1288 1 / max(shapeList[0]),
1289 )
1290 }
1291 data_range = TosaTensorValuesGen._get_data_range(
1292 testGen, dtype, highval_lookup
1293 )
1294 assert data_range is not None
1295 argsDict["data_range"] = data_range
1296
1297 return TosaTensorValuesGen.tvgLazyGenDefault(
1298 testGen, opName, dtypeList, shapeList, argsDict, error_name
1299 )
1300
Jeremy Johnson30476252023-11-20 16:15:30 +00001301 # Set the POW exponent high data range
1302 TVG_FLOAT_HIGH_VALUE_POW_EXP = {
1303 DType.FP32: 10.0,
1304 DType.FP16: 10.0,
1305 DType.BF16: 10.0,
1306 }
1307 # POW highest base value (within a safe margin of error) that can be raised
1308 # to +ve exponent that doesn't become Infinity
1309 TVG_FLOAT_HIGH_VALUE_POW_BASE = {
1310 DType.FP32: math.floor(
1311 math.pow(
1312 TVG_FLOAT_HIGH_VALUE[DType.FP32],
1313 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1314 )
1315 ),
1316 DType.FP16: math.floor(
1317 math.pow(
1318 TVG_FLOAT_HIGH_VALUE[DType.FP16],
1319 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1320 )
1321 ),
1322 DType.BF16: math.floor(
1323 math.pow(
1324 TVG_FLOAT_HIGH_VALUE[DType.BF16],
1325 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1326 )
1327 ),
1328 }
1329 # POW lowest base value (within a safe margin of error) that can be raised
1330 # to -ve exponent that doesn't become Infinity
1331 TVG_FLOAT_LOW_VALUE_POW_BASE = {
1332 DType.FP32: math.ceil(
1333 math.pow(
1334 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
1335 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1336 )
1337 * 1000
1338 )
1339 / 1000,
1340 DType.FP16: math.ceil(
1341 math.pow(
1342 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
1343 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1344 )
1345 * 1000
1346 )
1347 / 1000,
1348 DType.BF16: math.ceil(
1349 math.pow(
1350 1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
1351 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1352 )
1353 * 1000
1354 )
1355 / 1000,
1356 }
1357
1358 @staticmethod
1359 def tvgPow(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1360 if error_name is not None:
1361 return TosaTensorValuesGen.tvgLazyGenDefault(
1362 testGen, opName, dtypeList, shapeList, argsDict, error_name
1363 )
1364 dtype = dtypeList[0]
1365 # Different ranges for POW
1366 test_set = argsDict["s"]
1367 if test_set == 0:
1368 # Positive base with fractional exponent
1369 base_range = TosaTensorValuesGen._get_data_range(
1370 testGen,
1371 dtype,
1372 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1373 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1374 )
1375 exp_range = TosaTensorValuesGen._get_data_range(
1376 testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
1377 )
1378 exp_round = False
1379 else:
1380 # Integer exponent
1381 exp_range = TosaTensorValuesGen._get_data_range(
1382 testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
1383 )
1384 exp_round = True
1385 if test_set == 1:
1386 # Positive base
1387 base_range = TosaTensorValuesGen._get_data_range(
1388 testGen,
1389 dtype,
1390 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1391 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1392 )
1393 else:
1394 assert test_set == 2
1395 # Negative base
1396 # Supply new look up tables with negative values
1397 base_range = TosaTensorValuesGen._get_data_range(
1398 testGen,
1399 dtype,
1400 {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
1401 {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
1402 )
1403
1404 data_range_list = (
1405 {
1406 "range": base_range,
1407 },
1408 {
1409 "range": exp_range,
1410 "round": exp_round,
1411 },
1412 )
1413 argsDict["data_range_list"] = data_range_list
1414 return TosaTensorValuesGen.tvgLazyGenDefault(
1415 testGen, opName, dtypeList, shapeList, argsDict, error_name
1416 )
1417
1418 @staticmethod
1419 def tvgLogRsqrt(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1420 # LOG & RSQRT data range from lowest expressible positive number to
1421 # largest to avoid NaNs
1422 data_range = TosaTensorValuesGen._get_data_range(
1423 testGen,
1424 dtypeList[0],
1425 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1426 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1427 )
1428 if data_range:
1429 argsDict["data_range"] = data_range
1430
1431 return TosaTensorValuesGen.tvgLazyGenDefault(
1432 testGen, opName, dtypeList, shapeList, argsDict, error_name
1433 )
1434
    # Set the EXP data range to the log of the largest to smallest values
    # to avoid infinities or making the result zero
    # (exp() maps log(low)..log(high) back into the expressible range)
    TVG_FLOAT_HIGH_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
    TVG_FLOAT_LOW_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
    }
1447
1448 @staticmethod
1449 def tvgExp(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1450 data_range = TosaTensorValuesGen._get_data_range(
1451 testGen,
1452 dtypeList[0],
1453 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1454 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1455 )
1456 if data_range:
1457 argsDict["data_range"] = data_range
1458
1459 return TosaTensorValuesGen.tvgLazyGenDefault(
1460 testGen, opName, dtypeList, shapeList, argsDict, error_name
1461 )
1462
1463 @staticmethod
1464 def tvgFullyConnected(
1465 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1466 ):
1467 dtype = dtypeList[0]
1468 if (
1469 error_name is None
1470 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
Jeremy Johnson718f3472023-11-30 14:18:19 +00001471 and dtype in (DType.BF16,)
Jeremy Johnson30476252023-11-20 16:15:30 +00001472 ):
Jeremy Johnson718f3472023-11-30 14:18:19 +00001473 # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
Jeremy Johnson30476252023-11-20 16:15:30 +00001474 # Limit ranges for (non error & non compliance) FP tests by using
1475 # values that can be multiplied on any axis to not hit infinity/NaN
1476 IC = shapeList[0][1]
1477 highval_lookup = {
1478 dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
1479 }
1480 data_range = TosaTensorValuesGen._get_data_range(
1481 testGen, dtype, highval_lookup
1482 )
1483 assert data_range is not None
1484 argsDict["data_range"] = data_range
1485
1486 return TosaTensorValuesGen.tvgLazyGenDefault(
1487 testGen, opName, dtypeList, shapeList, argsDict, error_name
1488 )
1489
Jeremy Johnson708da822023-11-15 16:25:45 +00001490 @staticmethod
1491 def tvgCast(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1492 in_dtype = dtypeList[0]
1493 out_dtype = argsDict["out_type"]
1494 # Create look up to limit input tensor to output type maximums to avoid
1495 # FP infinities and saturation of integers
1496 out_range = testGen.getDTypeRange(out_dtype, high_inclusive=True)
1497 highval_lookup = {in_dtype: out_range[1]}
1498 data_range = TosaTensorValuesGen._get_data_range(
1499 testGen,
1500 in_dtype,
1501 highval_lookup,
1502 )
1503
1504 assert data_range is not None
1505 argsDict["data_range"] = data_range
1506
1507 return TosaTensorValuesGen.tvgLazyGenDefault(
1508 testGen, opName, dtypeList, shapeList, argsDict, error_name
1509 )
1510
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001511 @staticmethod
1512 def tvgGather(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1513 K = shapeList[0][1]
1514
1515 # Fix the type of the indices tensor
1516 dtypeList[1] = DType.INT32
1517
1518 dtype = dtypeList[0]
1519 if not gtu.dtypeIsSupportedByCompliance(dtype):
1520 # Test unsupported by data generator
1521 op = testGen.TOSA_OP_LIST[opName]
1522 pCount, cCount = op["operands"]
1523 assert (
1524 pCount == 2 and cCount == 0
1525 ), "Op.GATHER must have 2 placeholders, 0 consts"
1526
1527 tens_ser_list = []
1528 for idx, shape in enumerate(shapeList):
1529 dtype = dtypeList[idx]
1530 if idx != 1:
1531 arr = testGen.getRandTensor(shape, dtype)
1532 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1533 else:
1534 # Limit data range of indices tensor upto K (exclusive)
1535 arr = testGen.getRandTensor(shape, dtype, (0, K))
1536 # To match old functionality - create indices as CONST
1537 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1538
1539 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1540
1541 else:
1542 # ERROR_IF or floating point test
1543 # Use inclusive values upto index K for indices tensor
1544 data_range_list = (
1545 {"range": None},
1546 {"range": (0, K - 1)},
1547 )
1548 argsDict["data_range_list"] = data_range_list
1549
1550 return TosaTensorValuesGen.tvgLazyGenDefault(
1551 testGen, opName, dtypeList, shapeList, argsDict, error_name
1552 )
1553
    @staticmethod
    def tvgScatter(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create SCATTER inputs with valid, non-repeating output indices.

        For types unsupported by the compliance data generator the indices
        tensor is built from per-batch permutations of 0..K-1 truncated to
        W entries, guaranteeing uniqueness; otherwise the default lazy
        generator is used with a 0..K-1 range for the indices.
        """
        K = shapeList[0][1]
        W = shapeList[2][1]

        # Work out an indices tensor here with data that doesn't exceed the
        # dimension K of the values_in tensor and does NOT repeat the same K
        # location as needed by the spec:
        # "It is not permitted to repeat the same output index within a single
        # SCATTER operation and so each output index occurs at most once."
        assert K >= W, "Op.SCATTER W must be smaller or equal to K"

        # Fix the type of the indices tensor
        dtypeList[1] = DType.INT32

        dtype = dtypeList[0]
        if not gtu.dtypeIsSupportedByCompliance(dtype):
            # Test unsupported by data generator
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 3 and cCount == 0
            ), "Op.SCATTER must have 3 placeholders, 0 consts"

            tens_ser_list = []
            for idx, shape in enumerate(shapeList):
                dtype = dtypeList[idx]
                if idx != 1:
                    arr = testGen.getRandTensor(shape, dtype)
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    # Create the indices array
                    assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
                    arr = []
                    for n in range(shape[0]):
                        # Get a shuffled list of output indices (0 to K-1) and
                        # limit length to W
                        arr.append(testGen.rng.permutation(K)[:W])
                    indices_arr = np.array(arr, dtype=np.int32)  # (N, W)
                    # To match old functionality - create indices as CONST
                    tens_ser_list.append(
                        testGen.ser.addConst(shape, dtype, indices_arr)
                    )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        else:
            # ERROR_IF or floating point test
            # Use inclusive values upto index K for indices tensor
            # NOTE(review): a range alone does not enforce the "no repeated
            # index" rule - presumably the lazy data generator handles
            # uniqueness; confirm
            data_range_list = (
                {"range": None},
                {"range": (0, K - 1)},
                {"range": None},
            )
            argsDict["data_range_list"] = data_range_list

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
1613
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001614
1615class TosaArgGen:
1616 """Argument generators create exhaustive or random lists of attributes for
1617 operators that take attributes or other parameters.
1618
1619 The return value is a list of (descriptive_name, [arglist]) tuples where
1620 the descriptive_name is appended to the test name and the arglist is expanded
1621 as arguments to the operator build function.
1622 """
1623
    def __init__(self):
        # No instance state - all argument generators are static methods
        pass
1626
1627 @staticmethod
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001628 def _add_data_generators(testGen, opName, dtype, arg_list, error_name):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001629 """Add extra tests for each type of data generator for this op."""
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001630 if (
1631 error_name is None
1632 and "data_gen" in testGen.TOSA_OP_LIST[opName]
1633 and gtu.dtypeIsSupportedByCompliance(dtype)
1634 ):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001635 if dtype in [DType.FP16, DType.FP32, DType.BF16]:
1636 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
1637 else:
1638 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
1639 else:
1640 # Error test or No data generator types listed - assume random
1641 dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
1642
1643 # Expand arg list with other data generator types
1644 new_arg_list = []
1645 for dg_type in dataGenTypesList:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001646 for arg_str, args_dict in arg_list:
1647 args_dict["dg_type"] = dg_type
Jeremy Johnson1271c442023-09-05 11:39:26 +01001648 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
Jeremy Johnson30476252023-11-20 16:15:30 +00001649 if error_name is None:
1650 num_test_sets = (
1651 args_dict["num_test_sets"]
1652 if "num_test_sets" in args_dict
1653 else 0
1654 )
1655 else:
1656 num_test_sets = 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01001657
1658 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
1659 # Extra tests for each dot product test set
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001660 dot_products = args_dict["dot_products"]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001661 if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001662 shape_info = (
1663 " ({})".format(testGen.shapeStr(args_dict["shape"]))
1664 if "shape" in args_dict
1665 else ""
1666 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01001667 print(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001668 f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
Jeremy Johnson1271c442023-09-05 11:39:26 +01001669 )
1670 continue
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001671 # KS and acc_type is required by all dot product generators
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001672 assert "ks" in args_dict
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001673 assert "acc_type" in args_dict
Jeremy Johnson1271c442023-09-05 11:39:26 +01001674
Jeremy Johnson30476252023-11-20 16:15:30 +00001675 num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS
1676
1677 if num_test_sets > 0:
1678 for s in range(0, num_test_sets):
1679 new_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001680 new_args_dict = args_dict.copy()
1681 new_args_dict["s"] = s
1682 new_arg_list.append((new_arg_str, new_args_dict))
Jeremy Johnson30476252023-11-20 16:15:30 +00001683 else:
1684 # Default is a single test
1685 new_arg_list.append((arg_str, args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001686
1687 return new_arg_list
1688
1689 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001690 def agNone(testGen, opName, shapeList, dtype, error_name=None):
1691 """A trivial argument generator for operators that don't take any
1692 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001693 arg_list = TosaArgGen._add_data_generators(
1694 testGen,
1695 opName,
1696 dtype,
1697 [("", {})],
1698 error_name,
1699 )
1700 # Return list of tuples: (arg_str, args_dict)
1701 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001702
1703 @staticmethod
Jeremy Johnson30476252023-11-20 16:15:30 +00001704 def agPow(testGen, opName, shapeList, dtype, error_name=None):
1705 """Pow operator needs different test sets to cover random numbers
1706 without creating NaNs or Infs"""
1707 arg_list = TosaArgGen._add_data_generators(
1708 testGen,
1709 opName,
1710 dtype,
1711 [("", {"num_test_sets": 3})],
1712 error_name,
1713 )
1714 # Return list of tuples: (arg_str, args_dict)
1715 return arg_list
1716
1717 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001718 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
1719 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001720 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001721 shape = shapeList[0]
1722
1723 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001724 # Set too small axis
1725 axes = [testGen.rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001726 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001727 # Set too large axis
1728 axes = [testGen.rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001729 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001730 # Create tests for each dimension
1731 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001732
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001733 opid = testGen.TOSA_OP_LIST[opName]["op"]
1734
1735 for a in axes:
1736 args_dict = {"axis": int(a)}
1737 if opid == Op.REDUCE_SUM:
1738 args_dict["dot_products"] = gtu.product(shape)
1739 args_dict["shape"] = shape
1740 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1741 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1742
1743 arg_list.append(("axis{}".format(a), args_dict))
1744
1745 arg_list = TosaArgGen._add_data_generators(
1746 testGen,
1747 opName,
1748 dtype,
1749 arg_list,
1750 error_name,
1751 )
1752 # Return list of tuples: (arg_str, args_dict)
1753 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001754
1755 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001756 def _calculate_sparsity(num_tests, sparsity_factor):
1757 sparsity = num_tests // sparsity_factor + 1
1758 # If there are only a small number of tests, just select them all
1759 if sparsity < 13:
1760 sparsity = 1
1761 # To get a variety of parameter combinations sparsity should not be a
1762 # multiple of 2, 3 or 5
1763 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1764 sparsity += 1
1765 return sparsity
1766
    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        """Build stride/pad/dilation argument combinations for convolutions.

        Used by CONV2D, CONV3D and DEPTHWISE_CONV2D. Generates either a
        comprehensive (sparsely sampled) set of stride/padding/dilation
        combinations, or - when level8k testing is enabled - only the
        boundary combinations for the 8K levels limits. Each entry carries
        the compliance information (ks, dot_products, shape, acc_type)
        used by the dot product data generator.
        """
        # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        # Shape: Batches, (Depth), Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: (OFM channels), (KD), KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Op type checks
        conv3d = opName.startswith("conv3d")
        depthwise = opName.startswith("depthwise")

        # Check the rank
        rank = 5 if conv3d else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits channels
        k_rank = rank - 2
        # Depthwise filters are (KH, KW, C, M) so the kernel starts at index 0
        k_pos = 0 if depthwise else 1
        k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
        # compliance size - KS
        k_size = gtu.product(k_shape)
        if not depthwise:
            k_size *= ifm_shape[-1]

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            if error_name == ErrorIf.PadSmallerZero:
                p_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                # Stride must be greater than 1 to force non-integer error
                startStride = (
                    1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
                )
                s_vals = [
                    x for x in range(startStride, testGen.args.max_conv_stride + 1)
                ]
            strides = {x for x in itertools.product(*([s_vals] * k_rank))}
            if error_name == ErrorIf.DilationSmallerOne:
                d_vals = [testGen.rng.choice(range(-5, 1))]
            else:
                d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
            dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[0, bigPadding]] * (k_rank * 2))
                            )
                        }
                    )
                bigStride = 8
                strides.update(
                    {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
                )
                bigDilation = 7
                dilations.update(
                    {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
                )
            max_dim_size = None

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(paddings) * len(strides) * len(dilations), sparsity_factor
            )
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            dilation_shape = [1] * k_rank
            pad_shape = [0] * k_rank * 2
            if conv3d:
                # Small stride apart from for big kernel (see below) to keep
                # tensor size/calculation small
                stride_shape = [1] * k_rank
                for idx in range(k_rank):
                    pad_offset = idx * 2
                    if k_shape[idx] == bigKernel:
                        # Padding shape needs to account for tensor shape
                        pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                        pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                        # Big stride to reduce output size
                        stride_shape[idx] = bigKernel
                    else:
                        # Account for kernel size
                        pad_shape[pad_offset] = k_shape[idx] - 1
            else:
                # Always have a large stride with extra padding and dilation to keep
                # tensor calculation reasonable
                stride_shape = [bigKernel] * k_rank
                for idx in range(k_rank):
                    # Dilation shape must account for kernel size
                    dilation_shape[idx] = bigKernel // k_shape[idx]
                    # Padding shape needs to accommodate tensor/kernel & dilation
                    pad_offset = idx * 2
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

            strides = {tuple(stride_shape)}
            dilations = {tuple(dilation_shape)}
            paddings = {tuple(pad_shape)}
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        # n counts every combination considered (not just those kept) so the
        # sparsity modulus samples evenly across the whole space
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a positive
                        # sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
                        and (
                            k_rank < 3
                            or (
                                (ifm_shape[3] - 1 + p[4] + p[5])
                                > d[2] * (k_shape[2] - 1)
                            )
                        )
                    ):
                        remainders = []
                        outputs = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            # partial is the (padded input - dilated kernel) extent;
                            # output dim = partial // stride + 1
                            partial = (
                                ifm_shape[index + 1]
                                - 1
                                + p[pad_offset]
                                + p[pad_offset + 1]
                                - (k_shape[index] - 1) * d[index]
                            )
                            remainders.append(partial % s[index])
                            outputs.append((partial // s[index]) + 1)

                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            if (
                                max_dim_size is not None
                                and max(outputs) >= max_dim_size
                            ):
                                # Test will consume too much memory - skip it
                                continue

                            # Compliance - number of dot product calculations
                            if depthwise:
                                # TODO - add support
                                dots = 0
                            else:
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, filter_shape[0])
                                )
                            args_dict = {
                                "acc_type": accum_dtype,
                                "stride": s,
                                "pad": p,
                                "dilation": d,
                                "kernel": k_shape,
                                "ks": k_size,
                                "dot_products": dots,
                                "shape": ifm_shape,
                            }

                            # Support for larger values than 9 needs different delimiter
                            delim = "" if max(s + p + d) <= 9 else "x"
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        delim.join([str(x) for x in s]),
                                        delim.join([str(x) for x in p]),
                                        delim.join([str(x) for x in d]),
                                    ),
                                    args_dict,
                                )
                            )
                    n += 1

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtypes[0],
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
1990
1991 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01001992 def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
1993
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001994 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001995 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01001996
1997 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson1271c442023-09-05 11:39:26 +01001998 accum_dtype = gtu.get_wrong_output_type(opName, testGen.rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01001999 elif error_name == ErrorIf.WrongInputType:
2000 # Pick some potentially correct output dtype if input type is incorrect
2001 accum_dtype = DType.INT32
2002 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +01002003 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01002004
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00002005 # Set up compliance info
2006 args_dict = {
2007 "acc_type": accum_dtype,
2008 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
2009 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
2010 "shape": shapeList[0],
2011 }
2012
2013 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
2014
2015 arg_list = TosaArgGen._add_data_generators(
2016 testGen,
2017 opName,
2018 input_dtype,
2019 arg_list,
2020 error_name,
2021 )
2022 # Return list of tuples: (arg_str, args_dict)
2023 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002024
2025 @staticmethod
2026 def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
2027 # Get valid accumulate type(s)
2028 if dtype == DType.INT8:
2029 accum_dtypes = [DType.INT32]
2030 elif dtype == DType.INT16:
2031 accum_dtypes = [DType.INT48]
2032 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002033 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002034 elif dtype == DType.BF16:
2035 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002036 elif dtype == DType.FP32:
2037 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01002038 elif error_name is None:
2039 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
2040
2041 if error_name == ErrorIf.WrongOutputType:
2042 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson1271c442023-09-05 11:39:26 +01002043 accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01002044 elif error_name == ErrorIf.WrongInputType:
2045 # Pick some potentially correct output dtype if input type is incorrect
2046 accum_dtypes = [DType.INT32]
2047
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002048 # Set up compliance info
2049 args_dict = {
2050 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
2051 # Set dot_products = N*H*W
2052 "dot_products": gtu.product(
2053 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
2054 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002055 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002056 }
2057
2058 # Create arg tuple of string and dict
2059 arg_list = []
2060 for a in accum_dtypes:
2061 d = args_dict.copy()
2062 d["acc_type"] = a
2063 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01002064
2065 arg_list = TosaArgGen._add_data_generators(
2066 testGen,
2067 opName,
2068 dtype,
2069 arg_list,
2070 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002071 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002072 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01002073 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002074
2075 @staticmethod
2076 def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002077 arg_list = []
2078
Jeremy Johnson0c716862023-04-13 17:18:19 +01002079 if testGen.args.level8k and error_name is not None:
2080 # Don't produce negative large tests
2081 return arg_list
2082
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002083 ifm_shape = shapeList[0]
2084 filter_shape = shapeList[1]
2085
Jeremy Johnson1271c442023-09-05 11:39:26 +01002086 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01002087
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002088 # Must be rank 4
2089 if error_name != ErrorIf.WrongRank:
2090 assert len(ifm_shape) == 4
2091 assert len(filter_shape) == 4
2092
Jeremy Johnson0c716862023-04-13 17:18:19 +01002093 k_shape = tuple(filter_shape[1:3])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002094
Jeremy Johnson0c716862023-04-13 17:18:19 +01002095 if not testGen.args.level8k:
2096 # Generate comprehensive argument lists
2097 # - except for named errors, which use specific invalid value(s)
2098 smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
2099 if error_name == ErrorIf.PadLargerEqualKernel:
2100 max_filter_size = -max(k_shape[0], k_shape[1])
2101 p_vals = [
2102 testGen.rng.choice(range(max_filter_size - 10, max_filter_size))
2103 ]
2104 else:
2105 p_vals = [
2106 x
2107 for x in range(
2108 smallest_padding_size, testGen.args.max_conv_padding + 1
2109 )
2110 ]
2111 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2112 if error_name == ErrorIf.StrideSmallerOne:
2113 # Can't use stride=0, as it is used to derive output shape, as a divisor
2114 s_vals = [testGen.rng.choice(range(-5, 0))]
2115 else:
2116 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
2117 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002118
Jeremy Johnson0c716862023-04-13 17:18:19 +01002119 if not error_name and testGen.args.oversize:
2120 # add some oversize argument values
2121 if max(ifm_shape) < 64:
2122 bigPadding = 9
2123 paddings.update(
2124 {
2125 x
2126 for x in itertools.product(
2127 *([[smallest_padding_size, bigPadding]] * 4)
2128 )
2129 }
2130 )
2131 bigStride = 8
2132 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
2133
2134 # There are too many parameter combinations, so generate them sparsely,
2135 # very sparse for negative tests
2136 sparsity_factor = 2 if error_name else 10
2137 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
2138 # If there are only a small number of tests, just select them all
2139 if sparsity < 13:
2140 sparsity = 1
2141 # To get a variety of parameter combinations sparsity should not be a
2142 # multiple of 2, 3 or 5
2143 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
2144 sparsity += 1
2145 else:
2146 # Only test 8k levels boundaries
2147 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2148 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2149 bigPadding = bigKernel
2150
2151 pad_shape = [0] * (len(k_shape) * 2)
2152 stride_shape = [1] * len(k_shape)
2153 # The point at which input dimension combined with the stride will
2154 # create large output sizes!
2155 LARGE_SIZE = 2
2156 for idx in range(len(k_shape)):
2157 pad_offset = idx * 2
2158 if k_shape[idx] == bigKernel:
2159 # Set large stride
2160 stride_shape[idx] = bigKernel
2161 # Use negative output padding to reduce shape size
2162 pad_shape[pad_offset] = -(bigPadding - 1)
2163 if ifm_shape[idx + 1] > LARGE_SIZE:
2164 pad_shape[pad_offset + 1] = -(bigPadding - 1)
2165 else:
2166 # The other dimension should be the bigKernel
2167 alt_idx = 1 - idx
2168 if (
2169 k_shape[alt_idx] == bigKernel
2170 and ifm_shape[alt_idx + 1] < LARGE_SIZE
2171 ):
2172 # As the input is small, the large stride won't
2173 # affect the output so we can add some padding
2174 pad_shape[pad_offset + 1] = bigPadding
2175
2176 strides = {tuple(stride_shape)}
2177 paddings = {tuple(pad_shape)}
2178
2179 # Currently allow all combinations that are reasonable size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002180 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002181
2182 n = 0
2183 for s in sorted(list(strides)):
2184 for p in sorted(list(paddings)):
TatWai Chong24594f52022-06-08 00:48:04 -07002185 if n % sparsity == 0:
2186 # Determine the output shape
Jeremy Johnson0c716862023-04-13 17:18:19 +01002187 oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
2188 ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
TatWai Chong24594f52022-06-08 00:48:04 -07002189 os = [ifm_shape[0], oh, ow, filter_shape[0]]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002190
2191 # Support for larger values than 9 needs different delimiter
2192 delim = "" if max(s + p) <= 9 else "x"
TatWai Chong24594f52022-06-08 00:48:04 -07002193 arg_list.append(
2194 (
James Ward8b390432022-08-12 20:48:56 +01002195 "acc{}_st{}_pad{}_os{}".format(
2196 testGen.typeStr(accum_dtype),
Jeremy Johnson0c716862023-04-13 17:18:19 +01002197 delim.join([str(x) for x in s]),
2198 delim.join([str(x) for x in p]),
TatWai Chong24594f52022-06-08 00:48:04 -07002199 "x".join([str(x) for x in os]),
2200 ),
James Ward8b390432022-08-12 20:48:56 +01002201 [accum_dtype, s, p, os],
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002202 )
TatWai Chong24594f52022-06-08 00:48:04 -07002203 )
2204 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002205
2206 return arg_list
2207
2208 @staticmethod
2209 def agPad(testGen, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002210 rank = len(shapeList[0])
2211
2212 # Exhaustively test combinations of padding on each side of each dimension
2213 # - the range of padding values is defined by pad_min and pad_max
2214 # - for padding >9, the name format needs to be more distinctive
2215 pad_min, pad_max = 0, 1
2216 pad_values = [x for x in range(pad_min, pad_max + 1)]
2217 if error_name == ErrorIf.PadSmallerZero:
2218 pad_values = [x for x in range(-2, 0)]
2219 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
2220 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
2221
2222 if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
2223 pad_const_int = testGen.getRandNumberDType(dtype)
2224 pad_const_fp = 0
James Wardf0890992022-11-17 11:15:14 +00002225 elif dtype in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002226 pad_const_int = 0
2227 pad_const_fp = testGen.getRandNumberDType(dtype)
2228 else:
2229 return []
2230
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002231 list_shape_pad_values = list(shape_pad_values)
2232 # If we are producing tests for rank 6 or greater use sparsity
2233 if len(list_shape_pad_values) > 1024:
2234 sparsity_factor = 2 if error_name else 120
2235 sparsity = TosaArgGen._calculate_sparsity(
2236 len(list_shape_pad_values), sparsity_factor
2237 )
2238 else:
2239 sparsity = 1
2240
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002241 # Build arg list
2242 arg_list = []
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002243 for n, paddings in enumerate(list_shape_pad_values):
James Ward8b390432022-08-12 20:48:56 +01002244 paddings = list(paddings)
2245 args_valid = True
2246
2247 if error_name == ErrorIf.PadSmallerZero:
2248 # Prevent negative output shapes while ensuring still testing for negative padding
2249 for i in range(rank):
2250 dim_after_padding = (
2251 paddings[i][0] + paddings[i][1] + shapeList[0][i]
2252 )
2253 if dim_after_padding < 1:
2254 paddings[i] = (0, 0)
2255 if all([p > -1 for p in paddings[i]]):
2256 args_valid = False
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002257 if args_valid and n % sparsity == 0:
James Ward8b390432022-08-12 20:48:56 +01002258 name = "pad"
2259 for r in range(rank):
2260 before, after = paddings[r]
2261 name = f"{name}{before}{after}"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002262 args_dict = {
2263 "pad": np.array(paddings),
2264 "pad_const_int": pad_const_int,
2265 "pad_const_fp": pad_const_fp,
2266 }
2267 arg_list.append((name, args_dict))
James Ward8b390432022-08-12 20:48:56 +01002268
2269 if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
2270 warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002271
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002272 arg_list = TosaArgGen._add_data_generators(
2273 testGen,
2274 opName,
2275 dtype,
2276 arg_list,
2277 error_name,
2278 )
2279
2280 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002281 return arg_list
2282
2283 @staticmethod
2284 def agPooling(testGen, opName, shapeList, dtype, error_name=None):
2285 arg_list = []
2286
2287 shape = shapeList[0]
2288 if error_name != ErrorIf.WrongRank:
2289 assert len(shape) == 4
2290
Jeremy Johnson0c716862023-04-13 17:18:19 +01002291 test_level8k = testGen.args.level8k and error_name is None
2292
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002293 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002294 startKernel = 2
2295 startPad = 0
2296 if not test_level8k:
2297 # Generate comprehensive argument lists
2298 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2299 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2300 # Stride must be greater than 1 to force non-integer error
2301 s_vals = [
2302 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2303 ]
2304 strides = {x for x in itertools.product(*([s_vals] * 2))}
2305 k_vals = [
2306 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2307 ]
2308 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2309 max_dim_size = None
2310 else:
2311 # Only test 8k levels
2312 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2313 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2314 strides = {(1, bigStride), (bigStride, 4)}
2315 kernels = {(1, bigKernel), (bigKernel, 3)}
2316 paddings = set()
2317 for s in sorted(list(strides)):
2318 for k in sorted(list(kernels)):
2319 padding = []
2320 for idx in range(len(k)):
2321 total_padding = s[idx] - shape[idx + 1] + k[idx]
2322 while total_padding < 0:
2323 # Must meet: shape + padding > kernel
2324 total_padding += s[idx]
2325 if total_padding < k[idx]:
2326 padding.extend([0, total_padding])
2327 else:
2328 # Note this may produce padding >= k[idx] which is not
2329 # allowed - but will be ignored in the creation loop below
2330 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2331 paddings.add(tuple(padding))
2332 # Create a limit for the output dimensions size
2333 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002334
James Ward8b390432022-08-12 20:48:56 +01002335 if opName == "max_pool2d":
2336 accum_dtypes = [None] # max_pool has no accumulate dtype
2337 elif dtype == DType.INT8 or dtype == DType.INT16:
2338 accum_dtypes = [DType.INT32]
2339 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002340 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002341 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002342 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01002343 elif error_name is None:
2344 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2345 else:
2346 # Set to something for the ErrorIf case which has
2347 # incorrect input data-type
2348 accum_dtypes = [DType.INT32]
2349
Jeremy Johnson0c716862023-04-13 17:18:19 +01002350 if not test_level8k:
2351 if testGen.args.oversize:
2352 # add some oversize argument values
2353 bigStride = 7
2354 bigKernel = 9
2355 strides.update(
2356 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002357 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002358 kernels.update(
2359 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2360 )
2361 if max(shape) < 64:
2362 # padding must be less than the kernel size
2363 bigPadding = bigKernel - 1
2364 paddings.update(
2365 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2366 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002367
Jeremy Johnson0c716862023-04-13 17:18:19 +01002368 # There are too many parameter combinations, so generate them sparsely,
2369 # very sparse for negative tests
2370 sparsity_factor = 2 if error_name else 500
2371 sparsity = (
2372 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2373 )
2374 else:
2375 # We have already limited test output combinations for 8k tests
2376 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002377
James Ward8b390432022-08-12 20:48:56 +01002378 arg_str = (
2379 "acc{}_st{}_kern{}_pad{}"
2380 if accum_dtypes[0] is not None
2381 else "st{}_kern{}_pad{}"
2382 )
2383
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002384 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01002385 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002386 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01002387
2388 # Support for larger values than 9 needs different delimiter
2389 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01002390 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01002391 delim.join([str(x) for x in stride]),
2392 delim.join([str(x) for x in kern]),
2393 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01002394 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002395 args_dict = {
2396 "stride": stride,
2397 "pad": pad,
2398 "kernel": kern,
2399 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002400 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002401 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
2402 }
James Ward8b390432022-08-12 20:48:56 +01002403
2404 if accum is not None:
2405 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002406 args_dict["acc_type"] = accum
2407 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002408
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002409 n = 0
James Ward8b390432022-08-12 20:48:56 +01002410 for a in accum_dtypes:
2411 for s in sorted(list(strides)):
2412 for p in sorted(list(paddings)):
2413 for k in sorted(list(kernels)):
2414 if error_name in [
2415 ErrorIf.StrideSmallerOne,
2416 ErrorIf.KernelSmallerOne,
2417 ErrorIf.PadSmallerZero,
2418 ErrorIf.PadLargerEqualKernel,
2419 ]:
2420 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
2421 testGen, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002422 )
James Ward8b390432022-08-12 20:48:56 +01002423 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002424 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002425 get_arg_list_element(a, sNew, pNew, kNew, shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002426 )
James Ward8b390432022-08-12 20:48:56 +01002427 elif (
2428 n % sparsity == 0
2429 # padding must not exceed the kernel size
2430 and p[0] < k[0]
2431 and p[1] < k[0]
2432 and p[2] < k[1]
2433 and p[3] < k[1]
2434 # the padded shape must exceed the kernel size
2435 and (shape[1] + p[0] + p[1]) > k[0]
2436 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002437 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002438 partial_h = shape[1] + p[0] + p[1] - k[0]
2439 partial_w = shape[2] + p[2] + p[3] - k[1]
2440 remainder_h = partial_h % s[0]
2441 remainder_w = partial_w % s[1]
2442 output_h = partial_h // s[0] + 1
2443 output_w = partial_w // s[1] + 1
2444 # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
James Ward8b390432022-08-12 20:48:56 +01002445 if (
2446 # the parameters must produce integer exact output
2447 error_name != ErrorIf.PoolingOutputShapeNonInteger
2448 and remainder_h == 0
2449 and remainder_w == 0
2450 ) or (
2451 error_name == ErrorIf.PoolingOutputShapeNonInteger
2452 and (remainder_h != 0 or remainder_w != 0)
2453 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002454 if (
2455 max_dim_size is not None
2456 and max(output_h, output_w) > max_dim_size
2457 ):
2458 # Test will consume too much memory - skip it
2459 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002460 # Dot products = N*OH*OW*C
2461 dp = gtu.product(
2462 (shape[0], output_h, output_w, shape[3])
2463 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002464 arg_list.append(
2465 get_arg_list_element(a, s, p, k, dp, shape)
2466 )
James Ward8b390432022-08-12 20:48:56 +01002467 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002468
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002469 # Now add data generator types
2470 arg_list = TosaArgGen._add_data_generators(
2471 testGen,
2472 opName,
2473 dtype,
2474 arg_list,
2475 error_name,
2476 )
2477
2478 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002479 return arg_list
2480
2481 @staticmethod
2482 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
2483 arg_list = []
2484
2485 # Enumerate the output types here
2486 if error_name == ErrorIf.WrongOutputType:
2487 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
2488 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00002489 dtypeList = [
2490 DType.BOOL,
2491 DType.INT16,
2492 DType.INT32,
2493 DType.FP16,
2494 DType.BF16,
2495 DType.FP32,
2496 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002497 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00002498 dtypeList = [
2499 DType.BOOL,
2500 DType.INT8,
2501 DType.INT32,
2502 DType.FP16,
2503 DType.BF16,
2504 DType.FP32,
2505 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002506 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00002507 dtypeList = [
2508 DType.BOOL,
2509 DType.INT8,
2510 DType.INT16,
2511 DType.FP16,
2512 DType.BF16,
2513 DType.FP32,
2514 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002515 elif inDtype == DType.BOOL:
2516 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01002517 elif inDtype == DType.FP16:
James Ward736fd1a2023-01-23 17:13:37 +00002518 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002519 elif inDtype == DType.BF16:
James Ward736fd1a2023-01-23 17:13:37 +00002520 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002521 elif inDtype == DType.FP32:
James Ward736fd1a2023-01-23 17:13:37 +00002522 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002523 elif error_name == ErrorIf.WrongInputType:
2524 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002525 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002526 else:
2527 raise Exception("Unexpected input dtype: {}".format(inDtype))
2528
2529 for dtype in dtypeList:
Jeremy Johnson708da822023-11-15 16:25:45 +00002530 arg_list.append(
2531 ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
2532 )
2533
2534 # Now add data generator types
2535 arg_list = TosaArgGen._add_data_generators(
2536 testGen,
2537 opName,
2538 dtype,
2539 arg_list,
2540 error_name,
2541 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002542
2543 return arg_list
2544
2545 @staticmethod
2546 def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
2547 arg_list = []
2548
2549 # Enumerate the output types here
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002550 for outDtype in [
2551 DType.UINT8,
2552 DType.INT8,
2553 DType.INT16,
2554 DType.INT32,
2555 DType.UINT16,
2556 ]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002557 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002558 outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002559 and error_name == ErrorIf.OutputZeroPointNotZero
2560 ):
2561 continue
2562 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002563 outDtype != DType.UINT16
2564 and error_name == ErrorIf.U16OutputZeroPointNotValid
2565 ) or (
2566 inDtype != DType.UINT16
2567 and error_name == ErrorIf.U16InputZeroPointNotValid
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002568 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002569 # ErrorIfs only valid with UINT16
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002570 continue
2571 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002572 inDtype == DType.UINT8
2573 and outDtype not in [DType.INT8, DType.INT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002574 and error_name != ErrorIf.WrongOutputType
2575 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002576 # The only output dtypes for UINT8 are INT8/INT16, skip all others
2577 continue
2578 if (
2579 inDtype not in [DType.INT8, DType.INT16]
2580 and outDtype == DType.UINT8
2581 and error_name != ErrorIf.WrongOutputType
2582 ):
2583 # The only input dtypes for UINT8 are INT8/INT16, skip all others
2584 continue
2585 if (
2586 inDtype == DType.UINT16
2587 and outDtype != DType.INT16
2588 and error_name != ErrorIf.WrongOutputType
2589 ):
2590 # The only output dtype for UINT16 is INT16, skip all others
2591 continue
2592 if (
2593 inDtype != DType.INT16
2594 and outDtype == DType.UINT16
2595 and error_name != ErrorIf.WrongOutputType
2596 ):
2597 # The only input dtype for UINT16 is INT16, skip all others
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002598 continue
2599 if (
2600 error_name == ErrorIf.WrongOutputType
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002601 and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002602 ):
2603 continue
2604
2605 for scale32 in [False, True]:
2606 if error_name == ErrorIf.ScaleTrue and not scale32:
2607 continue
2608 elif error_name == ErrorIf.ScaleNotTrue and scale32:
2609 continue
2610 for double_round in [False, True]:
2611 if error_name == ErrorIf.ScaleNotTrue and not double_round:
2612 continue
2613 for per_channel in [False, True]:
2614
2615 if (
2616 inDtype == DType.INT48
2617 and scale32
2618 and error_name != ErrorIf.ScaleTrue
2619 ):
2620 # Illegal condition. Must be scale32=False
2621 continue
2622 if (
2623 double_round
2624 and not scale32
2625 and error_name != ErrorIf.ScaleNotTrue
2626 ):
2627 # Illegal condition. ERROR_IF(!scale32 && double_round)
2628 continue
2629
2630 arg_list.append(
2631 (
2632 "out{}_sc{}_dr{}_pc{}".format(
Jeremy Johnson3b0544c2022-10-18 16:32:19 +01002633 testGen.typeStr(outDtype),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002634 int(scale32),
2635 int(double_round),
2636 int(per_channel),
2637 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002638 [outDtype, scale32, double_round, per_channel],
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002639 )
2640 )
2641
2642 return arg_list
2643
2644 @staticmethod
2645 def agMul(testGen, opName, shapeList, dtype, error_name=None):
2646 arg_list = []
2647
2648 if dtype is DType.INT32:
2649 for p in range(testGen.args.num_rand_permutations):
2650
2651 shift = testGen.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002652 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002653 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002654 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002655
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002656 arg_list = TosaArgGen._add_data_generators(
2657 testGen,
2658 opName,
2659 dtype,
2660 arg_list,
2661 error_name,
2662 )
2663 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002664 return arg_list
2665
2666 @staticmethod
2667 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
2668 arg_list = []
2669
2670 arg_list.append(("roundTrue", [True]))
2671 arg_list.append(("roundFalse", [False]))
2672
2673 return arg_list
2674
Luke Hutton57287132023-02-06 14:54:18 +00002675 @staticmethod
2676 def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
2677 arg_list = []
2678
2679 arg_list.append(("inverseTrue", [True]))
2680 arg_list.append(("inverseFalse", [False]))
2681
2682 return arg_list
2683
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002684 # Helper function for reshape. Gets some factors of a larger number.
2685 @staticmethod
2686 def getFactors(val, start=1):
2687 factors = []
2688
2689 for i in range(start, int(np.sqrt(val)) + 1):
2690 if (val % i) == 0:
2691 factors.append(i)
2692
2693 return factors
2694
2695 @staticmethod
2696 def agReshape(testGen, opName, shapeList, dtype, error_name=None):
2697 arg_list = []
2698
2699 origShape = shapeList[0]
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002700 totalElements = gtu.product(origShape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002701 factors = TosaArgGen.getFactors(totalElements)
2702
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002703 # Find new shapes up to the number of permutations asked for
2704 # This code is NOT fast. Fortunately, the numbers are fairly small.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002705 for p in range(testGen.args.num_rand_permutations):
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002706 # Rank from 1 to TOSA_TENSOR_MAX_RANK
2707 newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002708 if len(factors) < newRank:
2709 continue
2710
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002711 # escape_counter limits the generation of new shapes to a reasonable time
2712 for escape_counter in range(100):
2713
2714 # Generate the new shape of the chosen new rank
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002715 newShape = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002716 remainingElements = totalElements
2717 shuffledFactors = testGen.rng.permutation(factors)
2718 for i in range(1, newRank):
2719 # pick rank-1 factors
2720 newShape.append(shuffledFactors[0])
2721 remainingElements = remainingElements // shuffledFactors[0]
2722 shuffledFactors = testGen.rng.permutation(
2723 TosaArgGen.getFactors(remainingElements)
2724 )
2725 newShape.append(remainingElements)
2726
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002727 # Check for duplicates
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002728 duplicate = False
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00002729 for name, args_dict in arg_list:
2730 if args_dict["new_shape"] == newShape:
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002731 duplicate = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002732 break
2733
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002734 if not duplicate:
2735 outShape = "x".join([str(x) for x in newShape])
2736 arg_list.append(
2737 (
2738 "perm{}_rank{}_out{}".format(p, newRank, outShape),
2739 {"new_shape": newShape},
2740 )
2741 )
2742 # Found an output shape for this permutation
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002743 break
2744
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00002745 # Now add data generator types
2746 arg_list = TosaArgGen._add_data_generators(
2747 testGen,
2748 opName,
2749 dtype,
2750 arg_list,
2751 error_name,
2752 )
2753
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002754 return arg_list
2755
2756 @staticmethod
2757 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
2758 arg_list = []
2759
2760 ifm_shape = shapeList[0]
2761
2762 if error_name == ErrorIf.IndexOutsideBounds:
2763 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
2764 incorrect_small_index = range(-len(ifm_shape), 0)
2765 permutations = [p for p in itertools.permutations(incorrect_large_index)]
2766 permutations.extend(
2767 [p for p in itertools.permutations(incorrect_small_index)]
2768 )
2769 elif error_name == ErrorIf.IndexUsedTwice:
2770 # Create list with a duplicated index
2771 perm_range = list(range(len(ifm_shape)))
2772 index_choice = testGen.rng.choice(range(len(perm_range)))
2773 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
2774 permutations = [p for p in itertools.permutations(perm_range)]
2775
2776 else:
2777 # Get all permutations
2778 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
2779
2780 # Limit to possible permutations from shape dimension or argument setting
2781 limit = min(len(permutations), testGen.args.num_rand_permutations)
2782
2783 # Get random permutation generator that uses all permutations
2784 random_permutations = testGen.rng.permutation(permutations)
2785
2786 # Create list of required amount of permutations
2787 arg_list = [
2788 ("perm{}".format(p), [random_permutations[p].tolist()])
2789 for p in range(limit)
2790 ]
2791 return arg_list
2792
2793 @staticmethod
2794 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
2795 arg_list = []
2796
2797 ifm_shape = shapeList[0]
2798 rank = len(ifm_shape)
2799
2800 for p in range(testGen.args.num_rand_permutations):
2801 start = []
2802 size = []
2803
2804 valid = True
2805
2806 for i in range(rank):
2807 if ifm_shape[i] > 1:
2808 start.append(testGen.randInt(0, ifm_shape[i]))
2809 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
2810
2811 # Invalid slice size?
2812 if size[i] == 0:
2813 valid = False
2814 else:
2815 start.append(0)
2816 size.append(1)
2817
2818 if valid:
2819 # If ERROR_IF test required then incorrect start, size will be returned
2820 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
2821 testGen, error_name, ifm_shape, start, size
2822 )
2823 arg_list.append(("perm{}".format(p), [start, size]))
2824 return arg_list
2825
2826 @staticmethod
2827 def agTile(testGen, opName, shapeList, dtype, error_name=None):
2828 arg_list = []
2829
2830 ifm_shape = shapeList[0]
2831 rank = len(ifm_shape)
2832
2833 for p in range(testGen.args.num_rand_permutations):
2834
2835 # Pick a few random, but small multiple values
2836 # because otherwise this has a tendency to generate
2837 # enormous tensors
2838 multiples = []
2839 for i in range(rank):
2840 if ifm_shape[i] > 1000:
2841 # Multiple of 1 if ifm_shape dimension is large to reduce
2842 # tensor size
2843 multiples.append(1)
2844 elif max(ifm_shape) > 1000:
2845 multiples.append(2)
2846 else:
2847 multiples.append(testGen.randInt(1, 4))
2848 arg_list.append(("perm{}".format(p), [multiples]))
2849
2850 return arg_list
2851
    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Generate argument combinations for the RESIZE operator.

        Builds (arg_str, [mode, scale, offset, border, dtype, outputDType])
        tuples covering NEAREST and BILINEAR modes. The scale/offset/border
        parameters come from one of several random strategies (or the
        8k-level strategy when --level8k is set), and candidate parameter
        sets that would produce illegal or oversized output shapes are
        skipped. For ERROR_IF runs the chosen parameters are perturbed into
        invalid values via TosaErrorIfArgGen.eiResizeErrorIf.
        """
        arg_list = []
        ifm_shape = shapeList[0]

        # Strategy 1: scale by a common aspect ratio (optionally inverted),
        # with a random letterbox or pillarbox border on one axis.
        def get_aspect_ratio_resize_params():
            common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
            aspect_ratio = testGen.rng.choice(common_aspect_ratios)
            invert = testGen.rng.choice((False, True))
            letterbox = testGen.rng.choice((False, True))

            scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
            scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
            scale_y_d = scale_x_d = 1
            offset_x = offset_y = 0

            if letterbox:
                max_border = scale_y_n
                border_y = testGen.randInt(low=0, high=max_border)
                border_x = 0
            else:
                # Pillarboxing
                border_y = 0
                max_border = scale_x_n
                border_x = testGen.randInt(low=0, high=max_border)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)

            return scale, offset, border

        # Strategy 2: power-of-two upscale, or a small random downscale;
        # loops until a parameter set compatible with the input shape is found.
        def get_upscale_downscale_params():
            valid_params = False
            while not valid_params:
                upscale = testGen.rng.choice((False, True))

                # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
                origin_sampling = testGen.rng.choice((False, True))

                if upscale:
                    shift = testGen.randInt(low=1, high=4)
                    scale_x_d = scale_y_d = 1
                    scale_x_n = scale_y_n = (
                        1 << shift if origin_sampling else 2 << shift
                    )
                    border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                    offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
                else:
                    scale_x_n = 1
                    scale_y_n = 1

                    # Return list of valid scale_*_d values (max value 4) given input dim shape
                    def get_valid_denom(ifm_dim):
                        return [x for x in range(1, 5) if ifm_dim % x == 1]

                    # Generate list of valid downscale values and choose one randomly
                    valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                    valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                    if not valid_scale_y_ds and not valid_scale_x_ds:
                        # Bad parameters, skip
                        continue

                    if not valid_scale_y_ds:
                        scale_y_d = 1
                    else:
                        scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                    if not valid_scale_x_ds:
                        scale_x_d = 1
                    else:
                        scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                    border_x = border_y = 0
                    offset_y = testGen.randInt(0, 16 * scale_y_n)
                    offset_x = testGen.randInt(0, 16 * scale_x_n)
                valid_params = True

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        # Strategy 3: fully random parameters, clamped to the level-8k
        # maximum scale factor.
        def get_rand_params():
            # Grow the denominator just enough to bring n/d under max_scale
            def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
                scale = scale_n / scale_d
                if scale > max_scale:
                    factor = scale / max_scale
                    new_scale_d = math.ceil(scale_d * factor)
                    assert scale_n / new_scale_d <= max_scale
                    scale_d = new_scale_d
                return scale_d

            # Scale
            scale_y_n = testGen.randInt(low=1, high=(1 << 11))
            scale_x_n = testGen.randInt(low=1, high=(1 << 11))

            scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

            scale_y_d = fix_scale_to_max_scale(
                scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_x_d = fix_scale_to_max_scale(
                scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )

            # Offsets and border within the scale
            offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
            border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        # Strategy for --level8k runs: maximum scale on one axis, a small
        # downscale on the other, with extreme offset/border choices.
        def get_level_8k_params():
            # Create 64x scale - 64/1 to 2048/32
            scale_d = testGen.randInt(
                low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
            # Create half to fifth scaling
            scale_d_alt = testGen.randInt(low=2, high=6)
            scale_n_alt = 1
            switch = testGen.rng.choice((False, True))
            if switch:
                scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
            else:
                scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)

            offset_y = testGen.rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
            offset_x = testGen.rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
            offset = (offset_y, offset_x)
            border_y = testGen.rng.choice((-16 * scale[0], 0, scale[0] - 1))
            border_x = testGen.rng.choice((-16 * scale[2], 0, scale[2] - 1))
            border = (border_y, border_x)
            return scale, offset, border

        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.BF16:
                outputDTypeList = [DType.BF16]
            elif dtype == DType.FP32:
                outputDTypeList = [DType.FP32]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                perm = 0
                while perm < testGen.args.num_rand_permutations:
                    # Random choice of type of params we are testing
                    if not testGen.args.level8k:
                        _rnd_param_fn = testGen.rng.choice(
                            (
                                get_rand_params,
                                get_upscale_downscale_params,
                                get_aspect_ratio_resize_params,
                            )
                        )
                        scale, offset, border = _rnd_param_fn()
                    else:
                        scale, offset, border = get_level_8k_params()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        # Look for non-integer test
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            if perm > 0:
                                perm += 1
                            continue
                    else:
                        # Alter the scaling factors to make the output integer
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1
                        # Make sure we are still within max scaling
                        if (
                            scale_y_n / scale_y_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
                            scale_x_n / scale_x_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
                            # Skip the test as it is using too large a scaling factor
                            if perm > 0:
                                perm += 1
                            continue

                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1

                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        if not testGen.args.level8k or perm > 0:
                            perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= gtu.MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= gtu.MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]

                    # Common for all data types
                    if error_name is not None:
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        [
                            mode,
                            scale,
                            offset,
                            border,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)
        return arg_list
3171
3172 @staticmethod
3173 def agTable(testGen, opName, shapeList, dtype, error_name=None):
3174 arg_list = []
3175
3176 if dtype == DType.INT8:
3177 table = np.int32(
3178 testGen.rng.integers(low=-128, high=128, size=[256])
3179 ).tolist()
3180 else: # INT16
3181 table = np.int32(
3182 testGen.rng.integers(low=-32768, high=32768, size=[513])
3183 ).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003184 # Make sure all slopes are within REQUIRE min/max 16-bit int
3185 for idx in range(len(table) - 1):
3186 slope = table[idx + 1] - table[idx]
3187 # Alter the next table entry to force the slope to be ok
3188 if slope > 32767:
3189 table[idx + 1] -= slope - 32767
3190 if slope < -32768:
3191 table[idx + 1] -= slope + 32768
3192 slope = table[idx + 1] - table[idx]
3193 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003194 arg_list.append(
3195 (
3196 "",
3197 [table],
3198 )
3199 )
3200 return arg_list
3201
3202 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
3203 # CondIf generates the condition values here.
3204 # Convert to tensors in the build function, along with the
3205 # then and else blocks
3206 arg_list = []
3207
3208 for c in [False, True]:
3209 arg_list.append(("cond{}".format(int(c)), [c]))
3210
3211 return arg_list
3212
3213 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
3214 # While loop: 0 iterations, 1, more than 1
3215 arg_list = []
3216
3217 for iter in [0, 1, 4]:
3218 arg_list.append(("iter{}".format(iter), [iter]))
3219
3220 return arg_list