blob: 1e2382279204eb82a849a8def305eafe15c7262c [file] [log] [blame]
Luke Hutton261b7b62023-01-10 14:50:31 +00001# Copyright (c) 2021-2023, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
4import math
James Ward8b390432022-08-12 20:48:56 +01005import warnings
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01006
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
16# DTypeNames, DType, Op and ResizeMode are convenience variables to the
17# flatc-generated types that should be enums, but aren't
18
19
20class TosaQuantGen:
21 """QuantizedInfo random generator helper functions.
22
23 Specify with 'qgen': in the operator defintion.
24 """
25
26 def __init__(self):
27 pass
28
29 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +000030 def getZeroPoint(testGen, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010031
32 if dtype == DType.INT8:
Jeremy Johnson00423432022-09-12 17:27:37 +010033 if testGen.args.zeropoint is not None:
34 return min(127, max(-128, testGen.args.zeropoint))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010035 return testGen.randInt(-128, 128)
36 elif dtype == DType.UINT8:
Jeremy Johnson00423432022-09-12 17:27:37 +010037 if testGen.args.zeropoint is not None:
38 return min(255, max(0, testGen.args.zeropoint))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010039 return testGen.randInt(0, 256)
40 elif error_name in [
41 ErrorIf.InputZeroPointNotZero,
42 ErrorIf.WeightZeroPointNotZero,
43 ErrorIf.OutputZeroPointNotZero,
44 ]:
45 zero_point = testGen.randInt(-128, 128)
46 if zero_point == 0:
47 zero_point = 1
48 return zero_point
49 return 0
50
51 @staticmethod
52 def qgUnary(testGen, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010053 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000054 qinfo = [
55 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
56 TosaQuantGen.getZeroPoint(testGen, dtype),
57 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010058 elif error_name == ErrorIf.OutputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000059 qinfo = [
60 TosaQuantGen.getZeroPoint(testGen, dtype),
61 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
62 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010063 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000064 qinfo = [
65 TosaQuantGen.getZeroPoint(testGen, dtype),
66 TosaQuantGen.getZeroPoint(testGen, dtype),
67 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010068 return qinfo
69
70 @staticmethod
71 def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010072 if isinstance(dtype_or_dtypeList, list):
73 # a list of [input, weights, accumulator] dtypes
74 dtypeList = dtype_or_dtypeList
75 else:
76 # an int, [input, weights, accumulator] dtypes are the same
77 dtypeList = [dtype_or_dtypeList] * 3
78
79 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000080 qinfo = [
81 TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
82 TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
83 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010084 elif error_name == ErrorIf.WeightZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000085 qinfo = [
86 TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
87 TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
88 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010089 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000090 qinfo = [
91 TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
92 TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
93 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010094 return qinfo
95
96 @staticmethod
97 def qgMatmul(testGen, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010098 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000099 qinfo = [
100 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
101 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
102 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100103 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000104 qinfo = [
105 TosaQuantGen.getZeroPoint(testGen, dtype),
106 TosaQuantGen.getZeroPoint(testGen, dtype),
107 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100108 return qinfo
109
110 @staticmethod
111 def computeMultiplierAndShift(scaleFp, scale32):
112 # Derived from computeMultiplierAndShiftTosaScale32
113 # Provide a floating-point scaling factor and the scale32 parameter
114 # to compute the multiplier and shift
115
116 if scale32:
117 scaleBits = 31
118 else:
119 scaleBits = 15
120
121 m, shift = math.frexp(scaleFp)
122
123 if scaleFp < 0.0:
124 m = -m
125
126 multiplier = round(m * (1 << scaleBits))
127 assert multiplier <= (1 << scaleBits)
128
129 if multiplier == (1 << scaleBits):
130 multiplier = multiplier // 2
131 shift = shift + 1
132
133 shift = (-shift) + scaleBits
134 # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
135 # scaleFp, scaleBits, m, multiplier, shift))
136
137 # Adjust multiplier such that shift is in allowed value range.
138 if shift == 0:
139 multiplier = multiplier // 4
140 shift = shift + 2
141 elif shift == 1:
142 multiplier = multiplier // 2
143 shift = shift + 1
144 elif shift == 63:
145 multiplier = multiplier * 2
146 shift = shift - 1
147
148 assert multiplier <= (1 << scaleBits)
149 assert shift >= 2 and shift <= 62
150
151 return multiplier, shift
152
153
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        # Stateless: all generators are static methods
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Return the same random shape for every operand of the operator.

        For ErrorIf.RankMismatch, operands after the first get a shape of a
        different rank so the mismatch is triggered.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    # Rank 1 can only grow (rank-1 would give rank 0)
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Return identical rank 4 (NHWC) shapes with a constrained batch size."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgGather(testGen, opName, rank, error_name=None):
        """Return [values (N,K,C), indices (N,W)] shapes for GATHER."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_shape = testGen.makeShape(rank)
        values_shape = testGen.constrictBatchSize(values_shape)

        N = values_shape[0]
        W = testGen.makeDimension()
        indices_shape = [N, W]

        shape_list = [values_shape, indices_shape]
        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Return [values_in (N,K,C), indices (N,W), input (N,W,C)] for SCATTER."""
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = testGen.randInt(W_min, W_max) if W_min < W_max else W_min

        input_shape = [N, W, C]

        shape_list = []
        shape_list.append(values_in_shape)
        shape_list.append([N, W])  # indices
        shape_list.append(input_shape)

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Return shapes where one randomly chosen input has a dimension set to 1
        so that broadcasting is exercised (or deliberately broken for ERROR_IF).
        """
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        fuzz_idx = testGen.randInt(0, rank)

        for i in range(pl + const):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (OHWI), bias (OC)] shapes for CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Return [ifm (NDHWC), filter (ODHWI), bias (OC)] shapes for CONV3D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension()

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (OHWI), bias (OC)] shapes for TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (HWCM), bias (C*M)] for DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        """Return two identical NHW shapes (H and W powers of two) for FFT2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # Must add 2 when the dimension is 1 (1+1=2 is a power of two)
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        """Return one NHW shape (H and W powers of two) for RFFT2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Return [input (N,IC), filter (OC,IC), bias (OC)] for FULLY_CONNECTED."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Return [a (N,H,C), b (N,C,W)] shapes for MATMUL."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Return shapes for CONCAT, adding 0-3 extra inputs beyond the
        operator's declared operand count.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Rework a CONCAT shape list so the first shape is split along axis
        across the remaining inputs, keeping the overall output size down.
        """
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
624
625
626class TosaTensorValuesGen:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100627 """Tensor Value generators create the random data for each tensor in each test."""
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100628
629 def __init__(self):
630 pass
631
Jeremy Johnson1271c442023-09-05 11:39:26 +0100632 class TVGInfo:
633 """Enhanced tensor values information including data gen dict."""
634
635 def __init__(self, tensorList, dataGenDict):
636 self.tensorList = tensorList
637 self.dataGenDict = dataGenDict
638
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100639 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000640 def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100641 pCount, cCount = op["operands"]
642
643 tens = []
644 tens.extend(
645 testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
646 )
647 tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
648
649 return tens
650
    # Default high value for random numbers
    # Each entry is the largest finite value of the type:
    # (2 - 2^-mantissa_bits) * 2^max_exponent
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
    }

    # Default lowest normal values for random numbers
    # Each entry is the smallest positive normal value: 2^(1 - exponent_bias)
    TVG_FLOAT_LOW_VALUE = {
        DType.FP32: np.exp2(-126),
        DType.FP16: np.exp2(-14),
        DType.BF16: np.exp2(-126),
    }
664
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100665 @staticmethod
Jeremy Johnson30476252023-11-20 16:15:30 +0000666 def _get_data_range(testGen, dtype, highValueLookup, lowValueLookup=None):
667 # Return a tuple of (low,high) data range values for the given data
668 # type using a combination of per operator table limits, data limits
669 # and user supplied ranges for FP numbers
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000670 if dtype in highValueLookup:
Jeremy Johnson30476252023-11-20 16:15:30 +0000671 type_range = testGen.getDTypeRange(dtype, high_inclusive=True)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000672 high_val = highValueLookup[dtype]
Jeremy Johnson30476252023-11-20 16:15:30 +0000673 if lowValueLookup is not None and dtype in lowValueLookup:
674 low_val = lowValueLookup[dtype]
675 else:
676 low_val = -high_val
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000677 # Set the values to something that won't produce infinity whilst
Jeremy Johnson30476252023-11-20 16:15:30 +0000678 # respecting the default ranges if more/less than the low/high
679 # values
680 data_range = (
681 max(low_val, type_range[0]),
682 min(high_val, type_range[1]),
683 )
684 if data_range[0] > data_range[1]:
685 # Invalid data range from low to high created due to user
686 # constraints revert to using internal ranges as they are
687 # known to work
688 msg = f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
689 warnings.warn(msg)
690 data_range = (low_val, high_val)
691 return data_range
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000692 return None
693
694 @staticmethod
Jeremy Johnson1271c442023-09-05 11:39:26 +0100695 def tvgLazyGenDefault(
696 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
697 ):
698 # Variable inputs versus constants
699 pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000700 tens_ser_list = []
Jeremy Johnson1271c442023-09-05 11:39:26 +0100701
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100702 if (
703 error_name is not None
704 or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100705 or "data_gen" not in testGen.TOSA_OP_LIST[opName]
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100706 ):
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000707 # Fall back to internal data gen when dealing with unsupported types or ops
708 data_range = argsDict["data_range"] if "data_range" in argsDict else None
709 for idx, info in enumerate(zip(shapeList, dtypeList)):
Jeremy Johnson30476252023-11-20 16:15:30 +0000710 roundMode = False
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000711 shape, dtype = info
Jeremy Johnson30476252023-11-20 16:15:30 +0000712 if "data_range_list" in argsDict:
713 data_range = argsDict["data_range_list"][idx]["range"]
714 roundMode = (
715 "round" in argsDict["data_range_list"][idx]
716 and argsDict["data_range_list"][idx]["round"] is True
717 )
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000718 if data_range is not None and dtype not in (
719 DType.FP16,
720 DType.FP32,
721 DType.BF16,
722 ):
723 # Change from inclusive to exclusive range
724 data_range = (data_range[0], data_range[1] + 1)
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000725 # Ignore lazy data gen option and create data array using any range limits
726 arr = testGen.getRandTensor(shape, dtype, data_range)
Jeremy Johnson30476252023-11-20 16:15:30 +0000727 if roundMode:
728 arr = np.round(arr)
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000729 if idx < pCount:
730 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
731 else:
732 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100733
Jeremy Johnson1271c442023-09-05 11:39:26 +0100734 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
735
736 # Create data generator meta-data
737 dg_type = argsDict["dg_type"]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100738 tens_data = {
739 "version": "0.1",
740 "tensors": {},
741 }
742 dg_tens_meta = tens_data["tensors"]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100743 for idx, shape in enumerate(shapeList):
744
745 tens_meta = {}
746 tens_meta["generator"] = gtu.DataGenType(dg_type).name
747 tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
748 tens_meta["shape"] = [int(i) for i in shape]
749 tens_meta["input_pos"] = idx
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100750 tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()
Jeremy Johnson1271c442023-09-05 11:39:26 +0100751
752 if idx < pCount:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100753 tens_meta["input_type"] = "VARIABLE"
Jeremy Johnson1271c442023-09-05 11:39:26 +0100754 else:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100755 tens_meta["input_type"] = "CONSTANT"
Jeremy Johnson1271c442023-09-05 11:39:26 +0100756
757 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
758 info = {}
759 # TODO - generate seed for this generator based on test
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100760 info["rng_seed"] = 42
Jeremy Johnson30476252023-11-20 16:15:30 +0000761
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000762 data_range = None
Jeremy Johnson30476252023-11-20 16:15:30 +0000763 if "data_range_list" in argsDict:
764 data_range = argsDict["data_range_list"][idx]["range"]
765 if "round" in argsDict["data_range_list"][idx]:
766 info["round"] = argsDict["data_range_list"][idx]["round"]
767 elif "data_range" in argsDict:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100768 data_range = argsDict["data_range"]
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000769
770 if data_range is None:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100771 data_range = testGen.getDTypeRange(
772 dtypeList[idx], high_inclusive=True
773 )
774 info["range"] = [str(v) for v in data_range]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100775 tens_meta["pseudo_random_info"] = info
776 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
777 info = {}
778 info["s"] = argsDict["s"]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100779 info["ks"] = int(argsDict["ks"])
780 if "acc_type" in argsDict:
781 # Convert type number into JSON name
782 info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
783 "json"
784 ]
785 if "kernel" in argsDict:
786 info["kernel"] = [int(k) for k in argsDict["kernel"]]
787 if "axis" in argsDict:
788 info["axis"] = int(argsDict["axis"])
Jeremy Johnson1271c442023-09-05 11:39:26 +0100789 tens_meta["dot_product_info"] = info
790 else:
791 # TODO - other data gen type
792 assert False, "TODO: support other data gen types"
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100793
794 # Using the finished generate config meta data - generate the data if
795 # needed and assign a tensor name from the serializer
796
797 # Need to generate data when not lazy or for the bias tensor as we need
798 # to work out if the bias data is non-zero for compliance
799 if not testGen.args.lazy_data_gen or (
800 idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
801 ):
802 # Give this tensor a temporary name until we get one from the serializer
803 temp_name = f"placeholder_{idx}"
804 dg_tens_meta[temp_name] = tens_meta
805 # Create data now using the temporary name to access meta details
806 data = testGen.dgl.get_tensor_data(temp_name, tens_data)
807 # Remove the item as we will give it the correct name later
808 del dg_tens_meta[temp_name]
809
810 if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
811 # The KS value used by compliance verification is altered when the
812 # bias data is non-zero
813 if max(abs(data)) > 0.0:
814 argsDict["ksb"] = argsDict["ks"] + 1
815
816 if testGen.args.lazy_data_gen:
817 data = None
818
819 if tens_meta["input_type"] == "VARIABLE":
820 tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
821 else:
822 tens = testGen.ser.addConst(shape, dtypeList[idx], data)
823
824 tens_ser_list.append(tens)
825 # Add the meta data to the list using the serializer tensor name
Jeremy Johnson1271c442023-09-05 11:39:26 +0100826 dg_tens_meta[tens.name] = tens_meta
827
Jeremy Johnson1271c442023-09-05 11:39:26 +0100828 return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
829
830 @staticmethod
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000831 def tvgNegate(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100832 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000833 # Integer test
834 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100835 pCount, cCount = op["operands"]
836 assert (
837 pCount == 1 and cCount == 0
838 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100839 # Must create tensors with values within accumulator (int32) negatable
840 # range
841 max_val = (1 << 31) - 1
842 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100843 arr = np.int32(
844 testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
845 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000846 tens_ser_list = []
847 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100848 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
849 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000850 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100851 else:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000852 # ERROR_IF or floating point test
853 return TosaTensorValuesGen.tvgLazyGenDefault(
854 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100855 )
856
Jeremy Johnson30476252023-11-20 16:15:30 +0000857 # Set the ADD/SUB data range to half the largest value to avoid infinities
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000858 TVG_FLOAT_HIGH_VALUE_ADDSUB = {
859 DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
860 DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
861 DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
862 }
863
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100864 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000865 def tvgAddSub(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100866 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000867 # Make sure the integer operation does not cause value saturation - where
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100868 # the number wraps due to limited number of bits to store the answer
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000869 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100870 pCount, cCount = op["operands"]
871 assert (
872 pCount == 2 and cCount == 0
873 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000874 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100875 add = op["op"] == Op.ADD
876 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
877 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
878 if add:
879 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
880 else:
881 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
882
883 # Work out the saturation limits
884 max_i32 = (1 << 31) - 1
885 min_i32 = -(1 << 31)
886 max_arr = np.full(shapeList[1], max_i32)
887 min_arr = np.full(shapeList[1], min_i32)
888
889 # Find how much values exceed the maximum/minimums
890 sat_max_arr = np.maximum(res_arr - max_arr, 0)
891 sat_min_arr = np.minimum(res_arr - min_arr, 0)
892
893 if not add:
894 # Swap saturation values and negate values as we need to perform opposite operations
895 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
896
897 # Create new array of unsaturated values by clipping values as needed
898 b_unsat_arr = b_arr
899 if (sat_max_arr != 0).any():
900 # Clip values that cause saturation
901 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
902 # Reduce axes in unsaturated tensor to match original tensor
903 for axis, dim in enumerate(b_arr.shape):
904 if dim != b_unsat_arr.shape[axis]:
905 assert (
906 dim == 1
907 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
908 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
909
910 if (sat_min_arr != 0).any():
911 # Clip values that cause saturation
912 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
913 # Reduce axes in unsaturated tensor to match original tensor
914 for axis, dim in enumerate(b_arr.shape):
915 if dim != b_unsat_arr.shape[axis]:
916 assert (
917 dim == 1
918 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
919 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
920
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000921 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100922 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
923 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000924 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100925 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
926 )
927
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000928 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100929 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000930 # ERROR_IF or floating point test
931 data_range = TosaTensorValuesGen._get_data_range(
932 testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
933 )
934 if data_range:
935 argsDict["data_range"] = data_range
936
937 return TosaTensorValuesGen.tvgLazyGenDefault(
938 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100939 )
940
941 @staticmethod
942 def tvgCondIfWhileLoop(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000943 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100944 ):
945 if dtypeList[0] in (
946 DType.INT32,
947 DType.INT16,
948 DType.INT8,
949 ):
950 # Limit input tensors with cond_if_binary or while_loop to stop
951 # saturation of add/sub ops with int32 and keep all logical shift
952 # values between 0 to 31 for int16 or int8
953 pCount, cCount = op["operands"]
954 pRemain = pCount
955 placeholders = []
956 for idx, shape in enumerate(shapeList[:]):
957 if dtypeList[0] == DType.INT32:
958 arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
959 else:
960 arr = np.int32(
961 testGen.rng.integers(low=0, high=32, size=shapeList[idx])
962 )
963 if pRemain > 0:
964 placeholders.append(
965 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
966 )
967 pRemain -= 1
968 else:
969 placeholders.append(
970 testGen.ser.addConst(shape, dtypeList[idx], arr)
971 )
972
973 return placeholders
974 else:
975 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000976 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100977 )
978
979 @staticmethod
980 def tvgArithmeticRightShift(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000981 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100982 ):
983 pCount, cCount = op["operands"]
984 # Force value of operand[1] to be within [0, num_bits]
985 assert (
986 pCount == 2 and cCount == 0
987 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
988
989 placeholders = []
990 for idx, shape in enumerate(shapeList[:]):
991 if idx == 1:
992 if dtypeList[idx] == DType.INT8:
993 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
994 elif dtypeList[idx] == DType.INT16:
995 arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
996 elif dtypeList[idx] == DType.INT32:
997 arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
998 elif error_name == ErrorIf.WrongInputType:
999 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
1000 else:
1001 raise Exception("OpArithmeticRightShift: invalid input dtype")
1002 else:
1003 arr = testGen.getRandTensor(shape, dtypeList[idx])
1004 placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
1005
1006 return placeholders
1007
1008 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001009 def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001010 # Set datatype of condition tensor to boolean
1011 dtypeList[0] = DType.BOOL
1012
1013 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +00001014 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001015 )
1016
1017 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001018 def tvgIntDiv(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001019 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001020 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001021 pCount, cCount = op["operands"]
1022 assert (
1023 pCount == 2 and cCount == 0
1024 ), "Op.INTDIV must have 2 placeholders, 0 consts"
1025
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001026 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001027
1028 # Two invalid cases for Op.INTDIV:
1029 # 1. divisor == 0
1030 # 2. dividend == -(1<<31) and divisor == -1
1031 while True:
1032 dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1033 divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
1034
1035 if (divisor_arr == 0).any():
1036 continue
1037
1038 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
1039 continue
1040
1041 break
1042
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001043 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001044 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1045 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001046 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001047 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1048 )
1049
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001050 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001051 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001052 return TosaTensorValuesGen.tvgLazyGenDefault(
1053 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001054 )
1055
Jeremy Johnson30476252023-11-20 16:15:30 +00001056 # Set the MUL data range to the square root of the largest value
1057 # to avoid infinities
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001058 TVG_FLOAT_HIGH_VALUE_MUL = {
1059 DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1060 DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1061 DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1062 }
1063
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001064 @staticmethod
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001065 def tvgMul(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1066 if error_name is not None or dtypeList[0] in (
1067 DType.FP16,
1068 DType.BF16,
1069 DType.FP32,
1070 ):
1071 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001072 data_range = TosaTensorValuesGen._get_data_range(
1073 testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
1074 )
1075 if data_range:
1076 argsDict["data_range"] = data_range
1077
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001078 return TosaTensorValuesGen.tvgLazyGenDefault(
1079 testGen, opName, dtypeList, shapeList, argsDict, error_name
1080 )
1081 else:
1082 # Integer test
1083 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001084 pCount, cCount = op["operands"]
1085 assert (
1086 pCount == 2 and cCount == 0
1087 ), "Op.MUL must have 2 placeholders, 0 consts"
1088
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001089 tens_ser_list = []
1090
1091 # Make sure multiply result in int32 range
1092 shift = argsDict["shift"]
1093 if dtypeList[0] == DType.INT8:
1094 num_bits = 8
1095 elif dtypeList[0] == DType.INT16:
1096 num_bits = 16
1097 elif dtypeList[0] == DType.INT32:
1098 num_bits = 32
1099 elif error_name == ErrorIf.WrongInputType:
1100 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001101 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001102 raise Exception("OpMul: invalid input dtype")
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001103
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001104 for idx, shape in enumerate(shapeList[:]):
1105 low = -(2 ** (num_bits - 1))
1106 high = (2 ** (num_bits - 1)) - 1
1107
1108 a_arr = np.int32(
1109 testGen.rng.integers(low=low, high=high, size=shapeList[0])
1110 )
1111 b_arr = np.int32(
1112 testGen.rng.integers(low=low, high=high, size=shapeList[1])
1113 )
1114
1115 i = 0
1116 while True:
1117
1118 a_arr_64 = a_arr.astype(np.int64)
1119 b_arr_64 = b_arr.astype(np.int64)
1120
1121 if shift > 0:
1122 rounding = 1 << (shift - 1)
1123 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001124 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001125 result_arr = a_arr_64 * b_arr_64
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001126
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001127 if (result_arr > -(2**31)).all() and (
1128 result_arr <= ((2**31) - 1)
1129 ).all():
1130 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001131
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001132 i = i + 1
1133 a_arr = a_arr // 2
1134 b_arr = b_arr // 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001135
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001136 tens_ser_list.append(
1137 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001138 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001139 tens_ser_list.append(
1140 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1141 )
1142
1143 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001144
1145 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001146 def tvgConcat(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001147 count = len(shapeList) - testGen.args.num_const_inputs_concat
1148 if count < 1:
1149 count = 1
1150 if testGen.args.num_const_inputs_concat == 0:
1151 count = len(shapeList)
1152
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001153 shapeList = TosaTensorGen.tgConcatConstInput(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001154 testGen, shapeList, argsDict["axis"], error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001155 )
1156
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001157 tens_ser_list = []
1158 tens_ser_list.extend(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001159 testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
1160 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001161 tens_ser_list.extend(
1162 testGen.buildConstTensors(shapeList[count:], dtypeList[count:])
1163 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001164
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001165 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001166
1167 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001168 def tvgLogicalShift(
1169 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1170 ):
1171 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001172 pCount, cCount = op["operands"]
1173 assert (
1174 pCount == 2 and cCount == 0
1175 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
1176 values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1177 shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001178 tens_ser_list = []
1179 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001180 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1181 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001182 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001183 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1184 )
1185
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001186 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001187
1188 @staticmethod
Jeremy Johnsona0150012023-11-15 15:52:06 +00001189 def tvgEqual(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1190 if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
1191 # Integer
1192 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001193 pCount, cCount = op["operands"]
1194 assert (
1195 pCount == 2 and cCount == 0
1196 ), "Op.EQUAL must have 2 placeholders, 0 consts"
Jeremy Johnsona0150012023-11-15 15:52:06 +00001197
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001198 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1199 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
Jeremy Johnsona0150012023-11-15 15:52:06 +00001200
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001201 # Using random numbers means that it will be very unlikely that
1202 # there are any matching (equal) values, therefore force that
1203 # there are twice the number of matching values as the tensor rank
1204 for num in range(0, len(shapeList[0]) * 2):
1205 a_index = []
1206 b_index = []
1207 # Choose an index in each axis for the whole shape
1208 for axis in range(0, len(shapeList[0])):
1209 # Index can be up to the largest dimension in both shapes
1210 index = np.int32(
1211 testGen.rng.integers(
1212 0, max(shapeList[0][axis], shapeList[1][axis])
1213 )
1214 )
1215 # Reduce the index down to a shape's dim for broadcasting
1216 a_index.append(min(shapeList[0][axis] - 1, index))
1217 b_index.append(min(shapeList[1][axis] - 1, index))
1218
1219 a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
1220
Jeremy Johnsona0150012023-11-15 15:52:06 +00001221 tens_ser_list = []
1222 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001223 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1224 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001225 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001226 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1227 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001228 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001229 else:
Jeremy Johnsona0150012023-11-15 15:52:06 +00001230 # ERROR_IF or floating point test
1231 return TosaTensorValuesGen.tvgLazyGenDefault(
1232 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001233 )
1234
1235 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001236 def tvgReduceSum(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001237 dtype = dtypeList[0]
1238 if dtype == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001239 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001240 pCount, cCount = op["operands"]
1241 assert (
1242 pCount == 1 and cCount == 0
1243 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
1244 # Limit values so that the sum cannot exceed the range of an int32 during
1245 # summation of any axis
1246 range_val = int((1 << 31) / max(shapeList[0]))
1247 values_arr = np.int32(
1248 testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
1249 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001250 tens_ser_list = []
1251 tens_ser_list.append(
Jeremy Johnson30476252023-11-20 16:15:30 +00001252 testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001253 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001254 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001255 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001256 # ERROR_IF or dot product floating point test
Jeremy Johnson30476252023-11-20 16:15:30 +00001257 if (
1258 error_name is None
1259 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
1260 ):
1261 # Limit ranges for (non error & non compliance) tests by using
1262 # values that can be summed on any axis to not hit infinity
1263 highval_lookup = {
1264 dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
1265 / max(shapeList[0])
1266 }
1267 data_range = TosaTensorValuesGen._get_data_range(
1268 testGen, dtype, highval_lookup
1269 )
1270 assert data_range is not None
1271 argsDict["data_range"] = data_range
1272
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001273 return TosaTensorValuesGen.tvgLazyGenDefault(
1274 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001275 )
1276
Jeremy Johnson30476252023-11-20 16:15:30 +00001277 # Set the POW exponent high data range
1278 TVG_FLOAT_HIGH_VALUE_POW_EXP = {
1279 DType.FP32: 10.0,
1280 DType.FP16: 10.0,
1281 DType.BF16: 10.0,
1282 }
1283 # POW highest base value (within a safe margin of error) that can be raised
1284 # to +ve exponent that doesn't become Infinity
1285 TVG_FLOAT_HIGH_VALUE_POW_BASE = {
1286 DType.FP32: math.floor(
1287 math.pow(
1288 TVG_FLOAT_HIGH_VALUE[DType.FP32],
1289 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1290 )
1291 ),
1292 DType.FP16: math.floor(
1293 math.pow(
1294 TVG_FLOAT_HIGH_VALUE[DType.FP16],
1295 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1296 )
1297 ),
1298 DType.BF16: math.floor(
1299 math.pow(
1300 TVG_FLOAT_HIGH_VALUE[DType.BF16],
1301 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1302 )
1303 ),
1304 }
1305 # POW lowest base value (within a safe margin of error) that can be raised
1306 # to -ve exponent that doesn't become Infinity
1307 TVG_FLOAT_LOW_VALUE_POW_BASE = {
1308 DType.FP32: math.ceil(
1309 math.pow(
1310 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
1311 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1312 )
1313 * 1000
1314 )
1315 / 1000,
1316 DType.FP16: math.ceil(
1317 math.pow(
1318 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
1319 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1320 )
1321 * 1000
1322 )
1323 / 1000,
1324 DType.BF16: math.ceil(
1325 math.pow(
1326 1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
1327 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1328 )
1329 * 1000
1330 )
1331 / 1000,
1332 }
1333
1334 @staticmethod
1335 def tvgPow(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1336 if error_name is not None:
1337 return TosaTensorValuesGen.tvgLazyGenDefault(
1338 testGen, opName, dtypeList, shapeList, argsDict, error_name
1339 )
1340 dtype = dtypeList[0]
1341 # Different ranges for POW
1342 test_set = argsDict["s"]
1343 if test_set == 0:
1344 # Positive base with fractional exponent
1345 base_range = TosaTensorValuesGen._get_data_range(
1346 testGen,
1347 dtype,
1348 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1349 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1350 )
1351 exp_range = TosaTensorValuesGen._get_data_range(
1352 testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
1353 )
1354 exp_round = False
1355 else:
1356 # Integer exponent
1357 exp_range = TosaTensorValuesGen._get_data_range(
1358 testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
1359 )
1360 exp_round = True
1361 if test_set == 1:
1362 # Positive base
1363 base_range = TosaTensorValuesGen._get_data_range(
1364 testGen,
1365 dtype,
1366 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1367 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1368 )
1369 else:
1370 assert test_set == 2
1371 # Negative base
1372 # Supply new look up tables with negative values
1373 base_range = TosaTensorValuesGen._get_data_range(
1374 testGen,
1375 dtype,
1376 {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
1377 {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
1378 )
1379
1380 data_range_list = (
1381 {
1382 "range": base_range,
1383 },
1384 {
1385 "range": exp_range,
1386 "round": exp_round,
1387 },
1388 )
1389 argsDict["data_range_list"] = data_range_list
1390 return TosaTensorValuesGen.tvgLazyGenDefault(
1391 testGen, opName, dtypeList, shapeList, argsDict, error_name
1392 )
1393
1394 @staticmethod
1395 def tvgLogRsqrt(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1396 # LOG & RSQRT data range from lowest expressible positive number to
1397 # largest to avoid NaNs
1398 data_range = TosaTensorValuesGen._get_data_range(
1399 testGen,
1400 dtypeList[0],
1401 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1402 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1403 )
1404 if data_range:
1405 argsDict["data_range"] = data_range
1406
1407 return TosaTensorValuesGen.tvgLazyGenDefault(
1408 testGen, opName, dtypeList, shapeList, argsDict, error_name
1409 )
1410
    # Set the EXP data range to the log of the largest to smallest values
    # to avoid infinities or making the result zero
    # (exp(log(max)) == max and exp(log(min)) == min for each dtype).
    # NOTE: explicit dicts - class-body comprehensions cannot see the
    # TVG_FLOAT_*_VALUE class-scope names.
    TVG_FLOAT_HIGH_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
    TVG_FLOAT_LOW_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
    }
1423
1424 @staticmethod
1425 def tvgExp(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1426 data_range = TosaTensorValuesGen._get_data_range(
1427 testGen,
1428 dtypeList[0],
1429 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1430 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1431 )
1432 if data_range:
1433 argsDict["data_range"] = data_range
1434
1435 return TosaTensorValuesGen.tvgLazyGenDefault(
1436 testGen, opName, dtypeList, shapeList, argsDict, error_name
1437 )
1438
1439 @staticmethod
1440 def tvgFullyConnected(
1441 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1442 ):
1443 dtype = dtypeList[0]
1444 if (
1445 error_name is None
1446 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
Jeremy Johnson718f3472023-11-30 14:18:19 +00001447 and dtype in (DType.BF16,)
Jeremy Johnson30476252023-11-20 16:15:30 +00001448 ):
Jeremy Johnson718f3472023-11-30 14:18:19 +00001449 # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
Jeremy Johnson30476252023-11-20 16:15:30 +00001450 # Limit ranges for (non error & non compliance) FP tests by using
1451 # values that can be multiplied on any axis to not hit infinity/NaN
1452 IC = shapeList[0][1]
1453 highval_lookup = {
1454 dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
1455 }
1456 data_range = TosaTensorValuesGen._get_data_range(
1457 testGen, dtype, highval_lookup
1458 )
1459 assert data_range is not None
1460 argsDict["data_range"] = data_range
1461
1462 return TosaTensorValuesGen.tvgLazyGenDefault(
1463 testGen, opName, dtypeList, shapeList, argsDict, error_name
1464 )
1465
Jeremy Johnson708da822023-11-15 16:25:45 +00001466 @staticmethod
1467 def tvgCast(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1468 in_dtype = dtypeList[0]
1469 out_dtype = argsDict["out_type"]
1470 # Create look up to limit input tensor to output type maximums to avoid
1471 # FP infinities and saturation of integers
1472 out_range = testGen.getDTypeRange(out_dtype, high_inclusive=True)
1473 highval_lookup = {in_dtype: out_range[1]}
1474 data_range = TosaTensorValuesGen._get_data_range(
1475 testGen,
1476 in_dtype,
1477 highval_lookup,
1478 )
1479
1480 assert data_range is not None
1481 argsDict["data_range"] = data_range
1482
1483 return TosaTensorValuesGen.tvgLazyGenDefault(
1484 testGen, opName, dtypeList, shapeList, argsDict, error_name
1485 )
1486
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001487 @staticmethod
1488 def tvgGather(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1489 K = shapeList[0][1]
1490
1491 # Fix the type of the indices tensor
1492 dtypeList[1] = DType.INT32
1493
1494 dtype = dtypeList[0]
1495 if not gtu.dtypeIsSupportedByCompliance(dtype):
1496 # Test unsupported by data generator
1497 op = testGen.TOSA_OP_LIST[opName]
1498 pCount, cCount = op["operands"]
1499 assert (
1500 pCount == 2 and cCount == 0
1501 ), "Op.GATHER must have 2 placeholders, 0 consts"
1502
1503 tens_ser_list = []
1504 for idx, shape in enumerate(shapeList):
1505 dtype = dtypeList[idx]
1506 if idx != 1:
1507 arr = testGen.getRandTensor(shape, dtype)
1508 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1509 else:
1510 # Limit data range of indices tensor upto K (exclusive)
1511 arr = testGen.getRandTensor(shape, dtype, (0, K))
1512 # To match old functionality - create indices as CONST
1513 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1514
1515 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1516
1517 else:
1518 # ERROR_IF or floating point test
1519 # Use inclusive values upto index K for indices tensor
1520 data_range_list = (
1521 {"range": None},
1522 {"range": (0, K - 1)},
1523 )
1524 argsDict["data_range_list"] = data_range_list
1525
1526 return TosaTensorValuesGen.tvgLazyGenDefault(
1527 testGen, opName, dtypeList, shapeList, argsDict, error_name
1528 )
1529
    @staticmethod
    def tvgScatter(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create input data for SCATTER.

        Forces the indices tensor to int32 and builds indices that stay below
        K and never repeat within a batch, as required by the spec.
        """
        K = shapeList[0][1]
        W = shapeList[2][1]

        # Work out an indices tensor here with data that doesn't exceed the
        # dimension K of the values_in tensor and does NOT repeat the same K
        # location as needed by the spec:
        # "It is not permitted to repeat the same output index within a single
        # SCATTER operation and so each output index occurs at most once."
        assert K >= W, "Op.SCATTER W must be smaller or equal to K"

        # Fix the type of the indices tensor
        dtypeList[1] = DType.INT32

        dtype = dtypeList[0]
        if not gtu.dtypeIsSupportedByCompliance(dtype):
            # Test unsupported by data generator
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 3 and cCount == 0
            ), "Op.SCATTER must have 3 placeholders, 0 consts"

            tens_ser_list = []
            for idx, shape in enumerate(shapeList):
                dtype = dtypeList[idx]
                if idx != 1:
                    # values_in / input tensors - plain random data
                    arr = testGen.getRandTensor(shape, dtype)
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    # Create the indices array
                    assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
                    arr = []
                    for n in range(shape[0]):
                        # Get a shuffled list of output indices (0 to K-1) and
                        # limit length to W - guarantees uniqueness per batch
                        arr.append(testGen.rng.permutation(K)[:W])
                    indices_arr = np.array(arr, dtype=np.int32)  # (N, W)
                    # To match old functionality - create indices as CONST
                    tens_ser_list.append(
                        testGen.ser.addConst(shape, dtype, indices_arr)
                    )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        else:
            # ERROR_IF or floating point test
            # Use inclusive values upto index K for indices tensor
            # NOTE(review): unlike the branch above, this range does not
            # guarantee non-repeating indices - presumably acceptable for
            # these test modes; confirm against the data generator
            data_range_list = (
                {"range": None},
                {"range": (0, K - 1)},
                {"range": None},
            )
            argsDict["data_range_list"] = data_range_list

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
1589
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001590
1591class TosaArgGen:
1592 """Argument generators create exhaustive or random lists of attributes for
1593 operators that take attributes or other parameters.
1594
1595 The return value is a list of (descriptive_name, [arglist]) tuples where
1596 the descriptive_name is appended to the test name and the arglist is expanded
1597 as arguments to the operator build function.
1598 """
1599
    def __init__(self):
        # No instance state - all argument generators are static methods
        pass
1602
1603 @staticmethod
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001604 def _add_data_generators(testGen, opName, dtype, arg_list, error_name):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001605 """Add extra tests for each type of data generator for this op."""
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001606 if (
1607 error_name is None
1608 and "data_gen" in testGen.TOSA_OP_LIST[opName]
1609 and gtu.dtypeIsSupportedByCompliance(dtype)
1610 ):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001611 if dtype in [DType.FP16, DType.FP32, DType.BF16]:
1612 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
1613 else:
1614 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
1615 else:
1616 # Error test or No data generator types listed - assume random
1617 dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
1618
1619 # Expand arg list with other data generator types
1620 new_arg_list = []
1621 for dg_type in dataGenTypesList:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001622 for arg_str, args_dict in arg_list:
1623 args_dict["dg_type"] = dg_type
Jeremy Johnson1271c442023-09-05 11:39:26 +01001624 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
Jeremy Johnson30476252023-11-20 16:15:30 +00001625 if error_name is None:
1626 num_test_sets = (
1627 args_dict["num_test_sets"]
1628 if "num_test_sets" in args_dict
1629 else 0
1630 )
1631 else:
1632 num_test_sets = 0
Jeremy Johnson1271c442023-09-05 11:39:26 +01001633
1634 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
1635 # Extra tests for each dot product test set
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001636 dot_products = args_dict["dot_products"]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001637 if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001638 shape_info = (
1639 " ({})".format(testGen.shapeStr(args_dict["shape"]))
1640 if "shape" in args_dict
1641 else ""
1642 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01001643 print(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001644 f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
Jeremy Johnson1271c442023-09-05 11:39:26 +01001645 )
1646 continue
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001647 # KS and acc_type is required by all dot product generators
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001648 assert "ks" in args_dict
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001649 assert "acc_type" in args_dict
Jeremy Johnson1271c442023-09-05 11:39:26 +01001650
Jeremy Johnson30476252023-11-20 16:15:30 +00001651 num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS
1652
1653 if num_test_sets > 0:
1654 for s in range(0, num_test_sets):
1655 new_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001656 new_args_dict = args_dict.copy()
1657 new_args_dict["s"] = s
1658 new_arg_list.append((new_arg_str, new_args_dict))
Jeremy Johnson30476252023-11-20 16:15:30 +00001659 else:
1660 # Default is a single test
1661 new_arg_list.append((arg_str, args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001662
1663 return new_arg_list
1664
1665 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001666 def agNone(testGen, opName, shapeList, dtype, error_name=None):
1667 """A trivial argument generator for operators that don't take any
1668 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001669 arg_list = TosaArgGen._add_data_generators(
1670 testGen,
1671 opName,
1672 dtype,
1673 [("", {})],
1674 error_name,
1675 )
1676 # Return list of tuples: (arg_str, args_dict)
1677 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001678
1679 @staticmethod
Jeremy Johnson30476252023-11-20 16:15:30 +00001680 def agPow(testGen, opName, shapeList, dtype, error_name=None):
1681 """Pow operator needs different test sets to cover random numbers
1682 without creating NaNs or Infs"""
1683 arg_list = TosaArgGen._add_data_generators(
1684 testGen,
1685 opName,
1686 dtype,
1687 [("", {"num_test_sets": 3})],
1688 error_name,
1689 )
1690 # Return list of tuples: (arg_str, args_dict)
1691 return arg_list
1692
1693 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001694 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
1695 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001696 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001697 shape = shapeList[0]
1698
1699 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001700 # Set too small axis
1701 axes = [testGen.rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001702 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001703 # Set too large axis
1704 axes = [testGen.rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001705 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001706 # Create tests for each dimension
1707 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001708
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001709 opid = testGen.TOSA_OP_LIST[opName]["op"]
1710
1711 for a in axes:
1712 args_dict = {"axis": int(a)}
1713 if opid == Op.REDUCE_SUM:
1714 args_dict["dot_products"] = gtu.product(shape)
1715 args_dict["shape"] = shape
1716 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1717 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1718
1719 arg_list.append(("axis{}".format(a), args_dict))
1720
1721 arg_list = TosaArgGen._add_data_generators(
1722 testGen,
1723 opName,
1724 dtype,
1725 arg_list,
1726 error_name,
1727 )
1728 # Return list of tuples: (arg_str, args_dict)
1729 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001730
1731 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001732 def _calculate_sparsity(num_tests, sparsity_factor):
1733 sparsity = num_tests // sparsity_factor + 1
1734 # If there are only a small number of tests, just select them all
1735 if sparsity < 13:
1736 sparsity = 1
1737 # To get a variety of parameter combinations sparsity should not be a
1738 # multiple of 2, 3 or 5
1739 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1740 sparsity += 1
1741 return sparsity
1742
    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        """Build stride/pad/dilation argument combinations for convolutions.

        Returns a list of (arg_str, args_dict) tuples; args_dict carries the
        conv attributes plus compliance info (ks, dot_products, acc_type).
        """
        # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        # Shape: Batches, (Depth), Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: (OFM channels), (KD), KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Op type checks
        conv3d = opName.startswith("conv3d")
        depthwise = opName.startswith("depthwise")

        # Check the rank
        rank = 5 if conv3d else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits channels
        k_rank = rank - 2
        # Depthwise filters are (KH, KW, C, M) so the kernel starts at index 0
        k_pos = 0 if depthwise else 1
        k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
        # compliance size - KS
        k_size = gtu.product(k_shape)
        if not depthwise:
            k_size *= ifm_shape[-1]

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            if error_name == ErrorIf.PadSmallerZero:
                p_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                # Stride must be greater than 1 to force non-integer error
                startStride = (
                    1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
                )
                s_vals = [
                    x for x in range(startStride, testGen.args.max_conv_stride + 1)
                ]
            strides = {x for x in itertools.product(*([s_vals] * k_rank))}
            if error_name == ErrorIf.DilationSmallerOne:
                d_vals = [testGen.rng.choice(range(-5, 1))]
            else:
                d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
            dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[0, bigPadding]] * (k_rank * 2))
                            )
                        }
                    )
                bigStride = 8
                strides.update(
                    {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
                )
                bigDilation = 7
                dilations.update(
                    {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
                )
            max_dim_size = None

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(paddings) * len(strides) * len(dilations), sparsity_factor
            )
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            dilation_shape = [1] * k_rank
            pad_shape = [0] * k_rank * 2
            if conv3d:
                # Small stride apart from for big kernel (see below) to keep
                # tensor size/calculation small
                stride_shape = [1] * k_rank
                for idx in range(k_rank):
                    pad_offset = idx * 2
                    if k_shape[idx] == bigKernel:
                        # Padding shape needs to account for tensor shape
                        pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                        pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                        # Big stride to reduce output size
                        stride_shape[idx] = bigKernel
                    else:
                        # Account for kernel size
                        pad_shape[pad_offset] = k_shape[idx] - 1
            else:
                # Always have a large stride with extra padding and dilation to keep
                # tensor calculation reasonable
                stride_shape = [bigKernel] * k_rank
                for idx in range(k_rank):
                    # Dilation shape must account for kernel size
                    dilation_shape[idx] = bigKernel // k_shape[idx]
                    # Padding shape needs to accommodate tensor/kernel & dilation
                    pad_offset = idx * 2
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

            strides = {tuple(stride_shape)}
            dilations = {tuple(dilation_shape)}
            paddings = {tuple(pad_shape)}
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        # Sparsely sample the (stride, pad, dilation) cross product, keeping
        # only combinations that produce a valid, positive output shape
        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a positive
                        # sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
                        and (
                            k_rank < 3
                            or (
                                (ifm_shape[3] - 1 + p[4] + p[5])
                                > d[2] * (k_shape[2] - 1)
                            )
                        )
                    ):
                        remainders = []
                        outputs = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            # Standard conv output-size numerator for this axis
                            partial = (
                                ifm_shape[index + 1]
                                - 1
                                + p[pad_offset]
                                + p[pad_offset + 1]
                                - (k_shape[index] - 1) * d[index]
                            )
                            remainders.append(partial % s[index])
                            outputs.append((partial // s[index]) + 1)

                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            if (
                                max_dim_size is not None
                                and max(outputs) >= max_dim_size
                            ):
                                # Test will consume too much memory - skip it
                                continue

                            # Compliance - number of dot product calculations
                            if depthwise:
                                # TODO - add support
                                dots = 0
                            else:
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, filter_shape[0])
                                )
                            args_dict = {
                                "acc_type": accum_dtype,
                                "stride": s,
                                "pad": p,
                                "dilation": d,
                                "kernel": k_shape,
                                "ks": k_size,
                                "dot_products": dots,
                                "shape": ifm_shape,
                            }

                            # Support for larger values than 9 needs different delimiter
                            delim = "" if max(s + p + d) <= 9 else "x"
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        delim.join([str(x) for x in s]),
                                        delim.join([str(x) for x in p]),
                                        delim.join([str(x) for x in d]),
                                    ),
                                    args_dict,
                                )
                            )
                        # Only count combinations that passed the shape checks
                        n += 1

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtypes[0],
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
1966
1967 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01001968 def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
1969
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001970 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001971 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01001972
1973 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson1271c442023-09-05 11:39:26 +01001974 accum_dtype = gtu.get_wrong_output_type(opName, testGen.rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01001975 elif error_name == ErrorIf.WrongInputType:
1976 # Pick some potentially correct output dtype if input type is incorrect
1977 accum_dtype = DType.INT32
1978 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +01001979 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01001980
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001981 # Set up compliance info
1982 args_dict = {
1983 "acc_type": accum_dtype,
1984 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
1985 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
1986 "shape": shapeList[0],
1987 }
1988
1989 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
1990
1991 arg_list = TosaArgGen._add_data_generators(
1992 testGen,
1993 opName,
1994 input_dtype,
1995 arg_list,
1996 error_name,
1997 )
1998 # Return list of tuples: (arg_str, args_dict)
1999 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002000
2001 @staticmethod
2002 def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
2003 # Get valid accumulate type(s)
2004 if dtype == DType.INT8:
2005 accum_dtypes = [DType.INT32]
2006 elif dtype == DType.INT16:
2007 accum_dtypes = [DType.INT48]
2008 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002009 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002010 elif dtype == DType.BF16:
2011 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002012 elif dtype == DType.FP32:
2013 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01002014 elif error_name is None:
2015 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
2016
2017 if error_name == ErrorIf.WrongOutputType:
2018 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson1271c442023-09-05 11:39:26 +01002019 accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01002020 elif error_name == ErrorIf.WrongInputType:
2021 # Pick some potentially correct output dtype if input type is incorrect
2022 accum_dtypes = [DType.INT32]
2023
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002024 # Set up compliance info
2025 args_dict = {
2026 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
2027 # Set dot_products = N*H*W
2028 "dot_products": gtu.product(
2029 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
2030 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002031 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002032 }
2033
2034 # Create arg tuple of string and dict
2035 arg_list = []
2036 for a in accum_dtypes:
2037 d = args_dict.copy()
2038 d["acc_type"] = a
2039 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01002040
2041 arg_list = TosaArgGen._add_data_generators(
2042 testGen,
2043 opName,
2044 dtype,
2045 arg_list,
2046 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01002047 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002048 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01002049 return arg_list
James Ward8b390432022-08-12 20:48:56 +01002050
2051 @staticmethod
2052 def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002053 arg_list = []
2054
Jeremy Johnson0c716862023-04-13 17:18:19 +01002055 if testGen.args.level8k and error_name is not None:
2056 # Don't produce negative large tests
2057 return arg_list
2058
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002059 ifm_shape = shapeList[0]
2060 filter_shape = shapeList[1]
2061
Jeremy Johnson1271c442023-09-05 11:39:26 +01002062 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01002063
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002064 # Must be rank 4
2065 if error_name != ErrorIf.WrongRank:
2066 assert len(ifm_shape) == 4
2067 assert len(filter_shape) == 4
2068
Jeremy Johnson0c716862023-04-13 17:18:19 +01002069 k_shape = tuple(filter_shape[1:3])
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002070
Jeremy Johnson0c716862023-04-13 17:18:19 +01002071 if not testGen.args.level8k:
2072 # Generate comprehensive argument lists
2073 # - except for named errors, which use specific invalid value(s)
2074 smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
2075 if error_name == ErrorIf.PadLargerEqualKernel:
2076 max_filter_size = -max(k_shape[0], k_shape[1])
2077 p_vals = [
2078 testGen.rng.choice(range(max_filter_size - 10, max_filter_size))
2079 ]
2080 else:
2081 p_vals = [
2082 x
2083 for x in range(
2084 smallest_padding_size, testGen.args.max_conv_padding + 1
2085 )
2086 ]
2087 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2088 if error_name == ErrorIf.StrideSmallerOne:
2089 # Can't use stride=0, as it is used to derive output shape, as a divisor
2090 s_vals = [testGen.rng.choice(range(-5, 0))]
2091 else:
2092 s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
2093 strides = {x for x in itertools.product(*([s_vals] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002094
Jeremy Johnson0c716862023-04-13 17:18:19 +01002095 if not error_name and testGen.args.oversize:
2096 # add some oversize argument values
2097 if max(ifm_shape) < 64:
2098 bigPadding = 9
2099 paddings.update(
2100 {
2101 x
2102 for x in itertools.product(
2103 *([[smallest_padding_size, bigPadding]] * 4)
2104 )
2105 }
2106 )
2107 bigStride = 8
2108 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
2109
2110 # There are too many parameter combinations, so generate them sparsely,
2111 # very sparse for negative tests
2112 sparsity_factor = 2 if error_name else 10
2113 sparsity = len(paddings) * len(strides) // sparsity_factor + 1
2114 # If there are only a small number of tests, just select them all
2115 if sparsity < 13:
2116 sparsity = 1
2117 # To get a variety of parameter combinations sparsity should not be a
2118 # multiple of 2, 3 or 5
2119 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
2120 sparsity += 1
2121 else:
2122 # Only test 8k levels boundaries
2123 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2124 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2125 bigPadding = bigKernel
2126
2127 pad_shape = [0] * (len(k_shape) * 2)
2128 stride_shape = [1] * len(k_shape)
2129 # The point at which input dimension combined with the stride will
2130 # create large output sizes!
2131 LARGE_SIZE = 2
2132 for idx in range(len(k_shape)):
2133 pad_offset = idx * 2
2134 if k_shape[idx] == bigKernel:
2135 # Set large stride
2136 stride_shape[idx] = bigKernel
2137 # Use negative output padding to reduce shape size
2138 pad_shape[pad_offset] = -(bigPadding - 1)
2139 if ifm_shape[idx + 1] > LARGE_SIZE:
2140 pad_shape[pad_offset + 1] = -(bigPadding - 1)
2141 else:
2142 # The other dimension should be the bigKernel
2143 alt_idx = 1 - idx
2144 if (
2145 k_shape[alt_idx] == bigKernel
2146 and ifm_shape[alt_idx + 1] < LARGE_SIZE
2147 ):
2148 # As the input is small, the large stride won't
2149 # affect the output so we can add some padding
2150 pad_shape[pad_offset + 1] = bigPadding
2151
2152 strides = {tuple(stride_shape)}
2153 paddings = {tuple(pad_shape)}
2154
2155 # Currently allow all combinations that are reasonable size
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002156 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002157
2158 n = 0
2159 for s in sorted(list(strides)):
2160 for p in sorted(list(paddings)):
TatWai Chong24594f52022-06-08 00:48:04 -07002161 if n % sparsity == 0:
2162 # Determine the output shape
Jeremy Johnson0c716862023-04-13 17:18:19 +01002163 oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
2164 ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
TatWai Chong24594f52022-06-08 00:48:04 -07002165 os = [ifm_shape[0], oh, ow, filter_shape[0]]
Jeremy Johnson0c716862023-04-13 17:18:19 +01002166
2167 # Support for larger values than 9 needs different delimiter
2168 delim = "" if max(s + p) <= 9 else "x"
TatWai Chong24594f52022-06-08 00:48:04 -07002169 arg_list.append(
2170 (
James Ward8b390432022-08-12 20:48:56 +01002171 "acc{}_st{}_pad{}_os{}".format(
2172 testGen.typeStr(accum_dtype),
Jeremy Johnson0c716862023-04-13 17:18:19 +01002173 delim.join([str(x) for x in s]),
2174 delim.join([str(x) for x in p]),
TatWai Chong24594f52022-06-08 00:48:04 -07002175 "x".join([str(x) for x in os]),
2176 ),
James Ward8b390432022-08-12 20:48:56 +01002177 [accum_dtype, s, p, os],
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002178 )
TatWai Chong24594f52022-06-08 00:48:04 -07002179 )
2180 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002181
2182 return arg_list
2183
2184 @staticmethod
2185 def agPad(testGen, opName, shapeList, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002186 rank = len(shapeList[0])
2187
2188 # Exhaustively test combinations of padding on each side of each dimension
2189 # - the range of padding values is defined by pad_min and pad_max
2190 # - for padding >9, the name format needs to be more distinctive
2191 pad_min, pad_max = 0, 1
2192 pad_values = [x for x in range(pad_min, pad_max + 1)]
2193 if error_name == ErrorIf.PadSmallerZero:
2194 pad_values = [x for x in range(-2, 0)]
2195 axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
2196 shape_pad_values = itertools.product(*([axis_pad_values] * rank))
2197
2198 if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
2199 pad_const_int = testGen.getRandNumberDType(dtype)
2200 pad_const_fp = 0
James Wardf0890992022-11-17 11:15:14 +00002201 elif dtype in (DType.FP16, DType.BF16, DType.FP32):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002202 pad_const_int = 0
2203 pad_const_fp = testGen.getRandNumberDType(dtype)
2204 else:
2205 return []
2206
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002207 list_shape_pad_values = list(shape_pad_values)
2208 # If we are producing tests for rank 6 or greater use sparsity
2209 if len(list_shape_pad_values) > 1024:
2210 sparsity_factor = 2 if error_name else 120
2211 sparsity = TosaArgGen._calculate_sparsity(
2212 len(list_shape_pad_values), sparsity_factor
2213 )
2214 else:
2215 sparsity = 1
2216
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002217 # Build arg list
2218 arg_list = []
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002219 for n, paddings in enumerate(list_shape_pad_values):
James Ward8b390432022-08-12 20:48:56 +01002220 paddings = list(paddings)
2221 args_valid = True
2222
2223 if error_name == ErrorIf.PadSmallerZero:
2224 # Prevent negative output shapes while ensuring still testing for negative padding
2225 for i in range(rank):
2226 dim_after_padding = (
2227 paddings[i][0] + paddings[i][1] + shapeList[0][i]
2228 )
2229 if dim_after_padding < 1:
2230 paddings[i] = (0, 0)
2231 if all([p > -1 for p in paddings[i]]):
2232 args_valid = False
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002233 if args_valid and n % sparsity == 0:
James Ward8b390432022-08-12 20:48:56 +01002234 name = "pad"
2235 for r in range(rank):
2236 before, after = paddings[r]
2237 name = f"{name}{before}{after}"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002238 args_dict = {
2239 "pad": np.array(paddings),
2240 "pad_const_int": pad_const_int,
2241 "pad_const_fp": pad_const_fp,
2242 }
2243 arg_list.append((name, args_dict))
James Ward8b390432022-08-12 20:48:56 +01002244
2245 if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
2246 warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002247
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002248 arg_list = TosaArgGen._add_data_generators(
2249 testGen,
2250 opName,
2251 dtype,
2252 arg_list,
2253 error_name,
2254 )
2255
2256 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002257 return arg_list
2258
2259 @staticmethod
2260 def agPooling(testGen, opName, shapeList, dtype, error_name=None):
2261 arg_list = []
2262
2263 shape = shapeList[0]
2264 if error_name != ErrorIf.WrongRank:
2265 assert len(shape) == 4
2266
Jeremy Johnson0c716862023-04-13 17:18:19 +01002267 test_level8k = testGen.args.level8k and error_name is None
2268
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002269 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002270 startKernel = 2
2271 startPad = 0
2272 if not test_level8k:
2273 # Generate comprehensive argument lists
2274 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2275 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2276 # Stride must be greater than 1 to force non-integer error
2277 s_vals = [
2278 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2279 ]
2280 strides = {x for x in itertools.product(*([s_vals] * 2))}
2281 k_vals = [
2282 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2283 ]
2284 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2285 max_dim_size = None
2286 else:
2287 # Only test 8k levels
2288 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2289 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2290 strides = {(1, bigStride), (bigStride, 4)}
2291 kernels = {(1, bigKernel), (bigKernel, 3)}
2292 paddings = set()
2293 for s in sorted(list(strides)):
2294 for k in sorted(list(kernels)):
2295 padding = []
2296 for idx in range(len(k)):
2297 total_padding = s[idx] - shape[idx + 1] + k[idx]
2298 while total_padding < 0:
2299 # Must meet: shape + padding > kernel
2300 total_padding += s[idx]
2301 if total_padding < k[idx]:
2302 padding.extend([0, total_padding])
2303 else:
2304 # Note this may produce padding >= k[idx] which is not
2305 # allowed - but will be ignored in the creation loop below
2306 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2307 paddings.add(tuple(padding))
2308 # Create a limit for the output dimensions size
2309 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002310
James Ward8b390432022-08-12 20:48:56 +01002311 if opName == "max_pool2d":
2312 accum_dtypes = [None] # max_pool has no accumulate dtype
2313 elif dtype == DType.INT8 or dtype == DType.INT16:
2314 accum_dtypes = [DType.INT32]
2315 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002316 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002317 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002318 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01002319 elif error_name is None:
2320 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2321 else:
2322 # Set to something for the ErrorIf case which has
2323 # incorrect input data-type
2324 accum_dtypes = [DType.INT32]
2325
Jeremy Johnson0c716862023-04-13 17:18:19 +01002326 if not test_level8k:
2327 if testGen.args.oversize:
2328 # add some oversize argument values
2329 bigStride = 7
2330 bigKernel = 9
2331 strides.update(
2332 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002333 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002334 kernels.update(
2335 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2336 )
2337 if max(shape) < 64:
2338 # padding must be less than the kernel size
2339 bigPadding = bigKernel - 1
2340 paddings.update(
2341 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2342 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002343
Jeremy Johnson0c716862023-04-13 17:18:19 +01002344 # There are too many parameter combinations, so generate them sparsely,
2345 # very sparse for negative tests
2346 sparsity_factor = 2 if error_name else 500
2347 sparsity = (
2348 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2349 )
2350 else:
2351 # We have already limited test output combinations for 8k tests
2352 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002353
James Ward8b390432022-08-12 20:48:56 +01002354 arg_str = (
2355 "acc{}_st{}_kern{}_pad{}"
2356 if accum_dtypes[0] is not None
2357 else "st{}_kern{}_pad{}"
2358 )
2359
        def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
            """Build one (arg_str, args_dict) entry for a pooling argument set.

            Args:
                accum: accumulator DType, or None (max_pool2d has no accumulator)
                stride: (stride_y, stride_x) values
                pad: 4-element padding values
                kern: (kernel_y, kernel_x) values
                dot_products: number of dot products (ignored for error tests)
                shape: input tensor shape stored alongside the arguments

            NOTE(review): the mutable default `shape=[]` is only safe because
            the list is stored, never mutated here - confirm callers do not
            mutate it either.
            """
            # Return tuple containing the formatted argument string and
            # the corresponding argument values in a dictionary

            # Support for larger values than 9 needs different delimiter
            delim = "" if max(stride + kern + pad) <= 9 else "x"
            arg_str_elems = [
                delim.join([str(x) for x in stride]),
                delim.join([str(x) for x in kern]),
                delim.join([str(x) for x in pad]),
            ]
            args_dict = {
                "stride": stride,
                "pad": pad,
                "kernel": kern,
                "dot_products": dot_products,  # Ignored for error tests
                "shape": shape,
                "ks": gtu.product(kern),  # avg_pool2d: KS = KX*KY
            }

            # Accumulator type leads the argument string when present
            if accum is not None:
                arg_str_elems.insert(0, testGen.typeStr(accum))
                args_dict["acc_type"] = accum
            return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002384
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002385 n = 0
James Ward8b390432022-08-12 20:48:56 +01002386 for a in accum_dtypes:
2387 for s in sorted(list(strides)):
2388 for p in sorted(list(paddings)):
2389 for k in sorted(list(kernels)):
2390 if error_name in [
2391 ErrorIf.StrideSmallerOne,
2392 ErrorIf.KernelSmallerOne,
2393 ErrorIf.PadSmallerZero,
2394 ErrorIf.PadLargerEqualKernel,
2395 ]:
2396 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
2397 testGen, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002398 )
James Ward8b390432022-08-12 20:48:56 +01002399 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002400 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002401 get_arg_list_element(a, sNew, pNew, kNew, shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002402 )
James Ward8b390432022-08-12 20:48:56 +01002403 elif (
2404 n % sparsity == 0
2405 # padding must not exceed the kernel size
2406 and p[0] < k[0]
2407 and p[1] < k[0]
2408 and p[2] < k[1]
2409 and p[3] < k[1]
2410 # the padded shape must exceed the kernel size
2411 and (shape[1] + p[0] + p[1]) > k[0]
2412 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002413 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002414 partial_h = shape[1] + p[0] + p[1] - k[0]
2415 partial_w = shape[2] + p[2] + p[3] - k[1]
2416 remainder_h = partial_h % s[0]
2417 remainder_w = partial_w % s[1]
2418 output_h = partial_h // s[0] + 1
2419 output_w = partial_w // s[1] + 1
2420 # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
James Ward8b390432022-08-12 20:48:56 +01002421 if (
2422 # the parameters must produce integer exact output
2423 error_name != ErrorIf.PoolingOutputShapeNonInteger
2424 and remainder_h == 0
2425 and remainder_w == 0
2426 ) or (
2427 error_name == ErrorIf.PoolingOutputShapeNonInteger
2428 and (remainder_h != 0 or remainder_w != 0)
2429 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002430 if (
2431 max_dim_size is not None
2432 and max(output_h, output_w) > max_dim_size
2433 ):
2434 # Test will consume too much memory - skip it
2435 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002436 # Dot products = N*OH*OW*C
2437 dp = gtu.product(
2438 (shape[0], output_h, output_w, shape[3])
2439 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002440 arg_list.append(
2441 get_arg_list_element(a, s, p, k, dp, shape)
2442 )
James Ward8b390432022-08-12 20:48:56 +01002443 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002444
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002445 # Now add data generator types
2446 arg_list = TosaArgGen._add_data_generators(
2447 testGen,
2448 opName,
2449 dtype,
2450 arg_list,
2451 error_name,
2452 )
2453
2454 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002455 return arg_list
2456
2457 @staticmethod
2458 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
2459 arg_list = []
2460
2461 # Enumerate the output types here
2462 if error_name == ErrorIf.WrongOutputType:
2463 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
2464 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00002465 dtypeList = [
2466 DType.BOOL,
2467 DType.INT16,
2468 DType.INT32,
2469 DType.FP16,
2470 DType.BF16,
2471 DType.FP32,
2472 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002473 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00002474 dtypeList = [
2475 DType.BOOL,
2476 DType.INT8,
2477 DType.INT32,
2478 DType.FP16,
2479 DType.BF16,
2480 DType.FP32,
2481 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002482 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00002483 dtypeList = [
2484 DType.BOOL,
2485 DType.INT8,
2486 DType.INT16,
2487 DType.FP16,
2488 DType.BF16,
2489 DType.FP32,
2490 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002491 elif inDtype == DType.BOOL:
2492 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01002493 elif inDtype == DType.FP16:
James Ward736fd1a2023-01-23 17:13:37 +00002494 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002495 elif inDtype == DType.BF16:
James Ward736fd1a2023-01-23 17:13:37 +00002496 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002497 elif inDtype == DType.FP32:
James Ward736fd1a2023-01-23 17:13:37 +00002498 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002499 elif error_name == ErrorIf.WrongInputType:
2500 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002501 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002502 else:
2503 raise Exception("Unexpected input dtype: {}".format(inDtype))
2504
2505 for dtype in dtypeList:
Jeremy Johnson708da822023-11-15 16:25:45 +00002506 arg_list.append(
2507 ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
2508 )
2509
2510 # Now add data generator types
2511 arg_list = TosaArgGen._add_data_generators(
2512 testGen,
2513 opName,
2514 dtype,
2515 arg_list,
2516 error_name,
2517 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002518
2519 return arg_list
2520
2521 @staticmethod
2522 def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
2523 arg_list = []
2524
2525 # Enumerate the output types here
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002526 for outDtype in [
2527 DType.UINT8,
2528 DType.INT8,
2529 DType.INT16,
2530 DType.INT32,
2531 DType.UINT16,
2532 ]:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002533 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002534 outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002535 and error_name == ErrorIf.OutputZeroPointNotZero
2536 ):
2537 continue
2538 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002539 outDtype != DType.UINT16
2540 and error_name == ErrorIf.U16OutputZeroPointNotValid
2541 ) or (
2542 inDtype != DType.UINT16
2543 and error_name == ErrorIf.U16InputZeroPointNotValid
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002544 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002545 # ErrorIfs only valid with UINT16
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002546 continue
2547 if (
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002548 inDtype == DType.UINT8
2549 and outDtype not in [DType.INT8, DType.INT16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002550 and error_name != ErrorIf.WrongOutputType
2551 ):
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002552 # The only output dtypes for UINT8 are INT8/INT16, skip all others
2553 continue
2554 if (
2555 inDtype not in [DType.INT8, DType.INT16]
2556 and outDtype == DType.UINT8
2557 and error_name != ErrorIf.WrongOutputType
2558 ):
2559 # The only input dtypes for UINT8 are INT8/INT16, skip all others
2560 continue
2561 if (
2562 inDtype == DType.UINT16
2563 and outDtype != DType.INT16
2564 and error_name != ErrorIf.WrongOutputType
2565 ):
2566 # The only output dtype for UINT16 is INT16, skip all others
2567 continue
2568 if (
2569 inDtype != DType.INT16
2570 and outDtype == DType.UINT16
2571 and error_name != ErrorIf.WrongOutputType
2572 ):
2573 # The only input dtype for UINT16 is INT16, skip all others
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002574 continue
2575 if (
2576 error_name == ErrorIf.WrongOutputType
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002577 and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002578 ):
2579 continue
2580
2581 for scale32 in [False, True]:
2582 if error_name == ErrorIf.ScaleTrue and not scale32:
2583 continue
2584 elif error_name == ErrorIf.ScaleNotTrue and scale32:
2585 continue
2586 for double_round in [False, True]:
2587 if error_name == ErrorIf.ScaleNotTrue and not double_round:
2588 continue
2589 for per_channel in [False, True]:
2590
2591 if (
2592 inDtype == DType.INT48
2593 and scale32
2594 and error_name != ErrorIf.ScaleTrue
2595 ):
2596 # Illegal condition. Must be scale32=False
2597 continue
2598 if (
2599 double_round
2600 and not scale32
2601 and error_name != ErrorIf.ScaleNotTrue
2602 ):
2603 # Illegal condition. ERROR_IF(!scale32 && double_round)
2604 continue
2605
2606 arg_list.append(
2607 (
2608 "out{}_sc{}_dr{}_pc{}".format(
Jeremy Johnson3b0544c2022-10-18 16:32:19 +01002609 testGen.typeStr(outDtype),
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002610 int(scale32),
2611 int(double_round),
2612 int(per_channel),
2613 ),
Jeremy Johnsonf7f78ae2022-05-25 15:26:38 +01002614 [outDtype, scale32, double_round, per_channel],
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002615 )
2616 )
2617
2618 return arg_list
2619
2620 @staticmethod
2621 def agMul(testGen, opName, shapeList, dtype, error_name=None):
2622 arg_list = []
2623
2624 if dtype is DType.INT32:
2625 for p in range(testGen.args.num_rand_permutations):
2626
2627 shift = testGen.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002628 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002629 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002630 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002631
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002632 arg_list = TosaArgGen._add_data_generators(
2633 testGen,
2634 opName,
2635 dtype,
2636 arg_list,
2637 error_name,
2638 )
2639 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002640 return arg_list
2641
2642 @staticmethod
2643 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
2644 arg_list = []
2645
2646 arg_list.append(("roundTrue", [True]))
2647 arg_list.append(("roundFalse", [False]))
2648
2649 return arg_list
2650
Luke Hutton57287132023-02-06 14:54:18 +00002651 @staticmethod
2652 def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
2653 arg_list = []
2654
2655 arg_list.append(("inverseTrue", [True]))
2656 arg_list.append(("inverseFalse", [False]))
2657
2658 return arg_list
2659
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002660 # Helper function for reshape. Gets some factors of a larger number.
2661 @staticmethod
2662 def getFactors(val, start=1):
2663 factors = []
2664
2665 for i in range(start, int(np.sqrt(val)) + 1):
2666 if (val % i) == 0:
2667 factors.append(i)
2668
2669 return factors
2670
2671 @staticmethod
2672 def agReshape(testGen, opName, shapeList, dtype, error_name=None):
2673 arg_list = []
2674
2675 origShape = shapeList[0]
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002676 totalElements = gtu.product(origShape)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002677 factors = TosaArgGen.getFactors(totalElements)
2678
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002679 # Find new shapes up to the number of permutations asked for
2680 # This code is NOT fast. Fortunately, the numbers are fairly small.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002681 for p in range(testGen.args.num_rand_permutations):
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00002682 # Rank from 1 to TOSA_TENSOR_MAX_RANK
2683 newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002684 if len(factors) < newRank:
2685 continue
2686
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002687 # escape_counter limits the generation of new shapes to a reasonable time
2688 for escape_counter in range(100):
2689
2690 # Generate the new shape of the chosen new rank
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002691 newShape = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002692 remainingElements = totalElements
2693 shuffledFactors = testGen.rng.permutation(factors)
2694 for i in range(1, newRank):
2695 # pick rank-1 factors
2696 newShape.append(shuffledFactors[0])
2697 remainingElements = remainingElements // shuffledFactors[0]
2698 shuffledFactors = testGen.rng.permutation(
2699 TosaArgGen.getFactors(remainingElements)
2700 )
2701 newShape.append(remainingElements)
2702
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002703 # Check for duplicates
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002704 duplicate = False
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00002705 for name, args_dict in arg_list:
2706 if args_dict["new_shape"] == newShape:
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002707 duplicate = True
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002708 break
2709
Jeremy Johnsone1e611d2023-12-13 14:28:12 +00002710 if not duplicate:
2711 outShape = "x".join([str(x) for x in newShape])
2712 arg_list.append(
2713 (
2714 "perm{}_rank{}_out{}".format(p, newRank, outShape),
2715 {"new_shape": newShape},
2716 )
2717 )
2718 # Found an output shape for this permutation
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002719 break
2720
Jeremy Johnsonfe79acc2023-11-29 15:57:58 +00002721 # Now add data generator types
2722 arg_list = TosaArgGen._add_data_generators(
2723 testGen,
2724 opName,
2725 dtype,
2726 arg_list,
2727 error_name,
2728 )
2729
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002730 return arg_list
2731
2732 @staticmethod
2733 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
2734 arg_list = []
2735
2736 ifm_shape = shapeList[0]
2737
2738 if error_name == ErrorIf.IndexOutsideBounds:
2739 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
2740 incorrect_small_index = range(-len(ifm_shape), 0)
2741 permutations = [p for p in itertools.permutations(incorrect_large_index)]
2742 permutations.extend(
2743 [p for p in itertools.permutations(incorrect_small_index)]
2744 )
2745 elif error_name == ErrorIf.IndexUsedTwice:
2746 # Create list with a duplicated index
2747 perm_range = list(range(len(ifm_shape)))
2748 index_choice = testGen.rng.choice(range(len(perm_range)))
2749 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
2750 permutations = [p for p in itertools.permutations(perm_range)]
2751
2752 else:
2753 # Get all permutations
2754 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
2755
2756 # Limit to possible permutations from shape dimension or argument setting
2757 limit = min(len(permutations), testGen.args.num_rand_permutations)
2758
2759 # Get random permutation generator that uses all permutations
2760 random_permutations = testGen.rng.permutation(permutations)
2761
2762 # Create list of required amount of permutations
2763 arg_list = [
2764 ("perm{}".format(p), [random_permutations[p].tolist()])
2765 for p in range(limit)
2766 ]
2767 return arg_list
2768
2769 @staticmethod
2770 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
2771 arg_list = []
2772
2773 ifm_shape = shapeList[0]
2774 rank = len(ifm_shape)
2775
2776 for p in range(testGen.args.num_rand_permutations):
2777 start = []
2778 size = []
2779
2780 valid = True
2781
2782 for i in range(rank):
2783 if ifm_shape[i] > 1:
2784 start.append(testGen.randInt(0, ifm_shape[i]))
2785 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
2786
2787 # Invalid slice size?
2788 if size[i] == 0:
2789 valid = False
2790 else:
2791 start.append(0)
2792 size.append(1)
2793
2794 if valid:
2795 # If ERROR_IF test required then incorrect start, size will be returned
2796 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
2797 testGen, error_name, ifm_shape, start, size
2798 )
2799 arg_list.append(("perm{}".format(p), [start, size]))
2800 return arg_list
2801
2802 @staticmethod
2803 def agTile(testGen, opName, shapeList, dtype, error_name=None):
2804 arg_list = []
2805
2806 ifm_shape = shapeList[0]
2807 rank = len(ifm_shape)
2808
2809 for p in range(testGen.args.num_rand_permutations):
2810
2811 # Pick a few random, but small multiple values
2812 # because otherwise this has a tendency to generate
2813 # enormous tensors
2814 multiples = []
2815 for i in range(rank):
2816 if ifm_shape[i] > 1000:
2817 # Multiple of 1 if ifm_shape dimension is large to reduce
2818 # tensor size
2819 multiples.append(1)
2820 elif max(ifm_shape) > 1000:
2821 multiples.append(2)
2822 else:
2823 multiples.append(testGen.randInt(1, 4))
2824 arg_list.append(("perm{}".format(p), [multiples]))
2825
2826 return arg_list
2827
    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Generate (mode, scale, offset, border, dtypes) argument sets for RESIZE.

        Scale/offset/border values come from one of several randomized
        strategies (pure random, upscale/downscale, common aspect-ratio) or
        from a dedicated 8k-level strategy, then are filtered so that the
        resulting output dimensions are legal (integer, in range, not
        excessively large) - or deliberately illegal for ERROR_IF tests.

        Args:
            testGen: test generator providing rng, randInt, args and limits
            opName: operator name string
            shapeList: input shapes; shapeList[0] is the NHWC input shape
            dtype: input data type
            error_name: optional ERROR_IF case to generate

        Returns:
            List of (arg_str, arg_values) tuples where arg_values is
            [mode, scale, offset, border, dtype, outputDType].
        """
        arg_list = []
        ifm_shape = shapeList[0]

        def get_aspect_ratio_resize_params():
            """Produce scale/offset/border mimicking a letterbox or
            pillarbox resize at a common aspect ratio (possibly inverted)."""
            common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
            aspect_ratio = testGen.rng.choice(common_aspect_ratios)
            invert = testGen.rng.choice((False, True))
            letterbox = testGen.rng.choice((False, True))

            scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
            scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
            scale_y_d = scale_x_d = 1
            offset_x = offset_y = 0

            if letterbox:
                max_border = scale_y_n
                border_y = testGen.randInt(low=0, high=max_border)
                border_x = 0
            else:
                # Pillarboxing
                border_y = 0
                max_border = scale_x_n
                border_x = testGen.randInt(low=0, high=max_border)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)

            return scale, offset, border

        def get_upscale_downscale_params():
            """Produce scale/offset/border for a power-of-two upscale or a
            small integer downscale; retries until parameters are valid."""
            valid_params = False
            while not valid_params:
                upscale = testGen.rng.choice((False, True))

                # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
                origin_sampling = testGen.rng.choice((False, True))

                if upscale:
                    shift = testGen.randInt(low=1, high=4)
                    scale_x_d = scale_y_d = 1
                    scale_x_n = scale_y_n = (
                        1 << shift if origin_sampling else 2 << shift
                    )
                    border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                    offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
                else:
                    scale_x_n = 1
                    scale_y_n = 1

                    # Return list of valid scale_*_d values (max value 4) given input dim shape
                    # NOTE(review): condition is `% x == 1`, i.e. (dim - 1)
                    # divisible by x - presumably to keep (dim-1)*n/d an
                    # integer; confirm against the resize output formula
                    def get_valid_denom(ifm_dim):
                        return [x for x in range(1, 5) if ifm_dim % x == 1]

                    # Generate list of valid downscale values and choose one randomly
                    valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                    valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                    if not valid_scale_y_ds and not valid_scale_x_ds:
                        # Bad parameters, skip
                        continue

                    if not valid_scale_y_ds:
                        scale_y_d = 1
                    else:
                        scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                    if not valid_scale_x_ds:
                        scale_x_d = 1
                    else:
                        scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                    border_x = border_y = 0
                    offset_y = testGen.randInt(0, 16 * scale_y_n)
                    offset_x = testGen.randInt(0, 16 * scale_x_n)
                valid_params = True

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_rand_params():
            """Produce fully random scale/offset/border, capping the
            effective scale at the 8k-level maximum."""

            def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
                # Grow the denominator until scale_n/scale_d <= max_scale
                scale = scale_n / scale_d
                if scale > max_scale:
                    factor = scale / max_scale
                    new_scale_d = math.ceil(scale_d * factor)
                    assert scale_n / new_scale_d <= max_scale
                    scale_d = new_scale_d
                return scale_d

            # Scale
            scale_y_n = testGen.randInt(low=1, high=(1 << 11))
            scale_x_n = testGen.randInt(low=1, high=(1 << 11))

            scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

            scale_y_d = fix_scale_to_max_scale(
                scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_x_d = fix_scale_to_max_scale(
                scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )

            # Offsets and border within the scale
            offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
            border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_level_8k_params():
            """Produce scale/offset/border exercising the 8k-level maximum
            scale on one axis and a downscale on the other."""
            # Create 64x scale - 64/1 to 2048/32
            # NOTE(review): `high` is a float here ((1 << 11) / max_scale) -
            # presumably randInt accepts/floors it; confirm
            scale_d = testGen.randInt(
                low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
            # Create half to fifth scaling
            scale_d_alt = testGen.randInt(low=2, high=6)
            scale_n_alt = 1
            switch = testGen.rng.choice((False, True))
            if switch:
                scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
            else:
                scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)

            # Offsets/borders at their extreme legal values (or zero)
            offset_y = testGen.rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
            offset_x = testGen.rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
            offset = (offset_y, offset_x)
            border_y = testGen.rng.choice((-16 * scale[0], 0, scale[0] - 1))
            border_x = testGen.rng.choice((-16 * scale[2], 0, scale[2] - 1))
            border = (border_y, border_x)
            return scale, offset, border

        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.BF16:
                outputDTypeList = [DType.BF16]
            elif dtype == DType.FP32:
                outputDTypeList = [DType.FP32]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                perm = 0
                while perm < testGen.args.num_rand_permutations:
                    # Random choice of type of params we are testing
                    if not testGen.args.level8k:
                        _rnd_param_fn = testGen.rng.choice(
                            (
                                get_rand_params,
                                get_upscale_downscale_params,
                                get_aspect_ratio_resize_params,
                            )
                        )
                        scale, offset, border = _rnd_param_fn()
                    else:
                        scale, offset, border = get_level_8k_params()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        # Look for non-integer test
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            if perm > 0:
                                perm += 1
                            continue
                    else:
                        # Alter the scaling factors to make the output integer
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1
                        # Make sure we are still within max scaling
                        if (
                            scale_y_n / scale_y_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
                            scale_x_n / scale_x_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
                            # Skip the test as it is using too large a scaling factor
                            if perm > 0:
                                perm += 1
                            continue

                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1

                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        if not testGen.args.level8k or perm > 0:
                            perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= gtu.MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= gtu.MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]

                    # Common for all data types
                    if error_name is not None:
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        [
                            mode,
                            scale,
                            offset,
                            border,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)
        return arg_list
3147
3148 @staticmethod
3149 def agTable(testGen, opName, shapeList, dtype, error_name=None):
3150 arg_list = []
3151
3152 if dtype == DType.INT8:
3153 table = np.int32(
3154 testGen.rng.integers(low=-128, high=128, size=[256])
3155 ).tolist()
3156 else: # INT16
3157 table = np.int32(
3158 testGen.rng.integers(low=-32768, high=32768, size=[513])
3159 ).tolist()
Jerry Ged511f9e2022-08-12 16:12:40 -07003160 # Make sure all slopes are within REQUIRE min/max 16-bit int
3161 for idx in range(len(table) - 1):
3162 slope = table[idx + 1] - table[idx]
3163 # Alter the next table entry to force the slope to be ok
3164 if slope > 32767:
3165 table[idx + 1] -= slope - 32767
3166 if slope < -32768:
3167 table[idx + 1] -= slope + 32768
3168 slope = table[idx + 1] - table[idx]
3169 assert slope <= 32767 and slope >= -32768
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01003170 arg_list.append(
3171 (
3172 "",
3173 [table],
3174 )
3175 )
3176 return arg_list
3177
3178 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
3179 # CondIf generates the condition values here.
3180 # Convert to tensors in the build function, along with the
3181 # then and else blocks
3182 arg_list = []
3183
3184 for c in [False, True]:
3185 arg_list.append(("cond{}".format(int(c)), [c]))
3186
3187 return arg_list
3188
3189 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
3190 # While loop: 0 iterations, 1, more than 1
3191 arg_list = []
3192
3193 for iter in [0, 1, 4]:
3194 arg_list.append(("iter{}".format(iter), [iter]))
3195
3196 return arg_list