blob: e6d7df8fc01d5f247b4c21967c1f4a02e58ba358 [file] [log] [blame]
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001# Copyright (c) 2021-2024, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
4import math
James Ward8b390432022-08-12 20:48:56 +01005import warnings
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01006
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
16# DTypeNames, DType, Op and ResizeMode are convenience variables to the
17# flatc-generated types that should be enums, but aren't
18
19
class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(testGen, dtype, error_name=None):
        """Return a zero point for a tensor of the given dtype.

        INT8/UINT8: the user supplied zero point (testGen.args.zeropoint)
        clamped to the type's range, otherwise a value from testGen.randInt.
        Other dtypes: 0, except for the *ZeroPointNotZero ERROR_IF tests
        where a random non-zero value is returned.
        """
        if dtype == DType.INT8:
            if testGen.args.zeropoint is not None:
                return min(127, max(-128, testGen.args.zeropoint))
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            if testGen.args.zeropoint is not None:
                return min(255, max(0, testGen.args.zeropoint))
            return testGen.randInt(0, 256)
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            zero_point = testGen.randInt(-128, 128)
            # The ERROR_IF test requires a non-zero value
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        """Return [input_zero_point, output_zero_point] for a unary operator.

        Only the zero point targeted by the ERROR_IF test (if any) is
        generated with error_name.
        """
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        """Return [input_zero_point, weight_zero_point] for a convolution.

        dtype_or_dtypeList is either a single dtype or a list of
        [input, weights, accumulator] dtypes.
        """
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        """Return two zero points for MATMUL (presumably for inputs A and B).

        For ErrorIf.InputZeroPointNotZero both zero points are generated
        with error_name.
        """
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        """Convert a floating-point scale into a (multiplier, shift) pair.

        Derived from computeMultiplierAndShiftTosaScale32. The magnitude of
        scaleFp is approximated as multiplier * 2**-shift, with a 31-bit
        (scale32=True) or 15-bit (scale32=False) multiplier. The sign of
        scaleFp is dropped.

        A shift of 0 or 1 is raised to 2 by dividing the multiplier (this
        does not preserve the scale - the value is not representable). A
        shift of 63 is lowered to 62 by halving the multiplier, which keeps
        the represented scale unchanged.

        Returns:
            (multiplier, shift) with 2 <= shift <= 62.
        Raises:
            AssertionError if the scale cannot be represented.
        """
        scaleBits = 31 if scale32 else 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        # Rounding can push the mantissa up to exactly 1.0; renormalise
        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            # scale = multiplier * 2**-shift, so reducing the shift by one
            # requires halving the multiplier (doubling it would both
            # quadruple the scale and overflow the multiplier range)
            multiplier = multiplier // 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift
152
153
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Return one identical shape per operand of the operator.

        For ErrorIf.RankMismatch the operands after the first are
        regenerated with a rank different from the first operand's.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Return one identical rank-4 (NHWC) shape per operand."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        return [shape.copy() for _ in range(pl + const)]

    @staticmethod
    def tgGather(testGen, opName, rank, error_name=None):
        """Return [values_shape (N, K, C), indices_shape (N, W)] for GATHER."""
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_shape = testGen.makeShape(rank)
        values_shape = testGen.constrictBatchSize(values_shape)

        N = values_shape[0]
        W = testGen.makeDimension()
        indices_shape = [N, W]

        shape_list = [values_shape, indices_shape]
        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Return [values_in (N, K, C), indices (N, W), input (N, W, C)] with W <= K."""
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = testGen.randInt(W_min, W_max) if W_min < W_max else W_min

        input_shape = [N, W, C]

        shape_list = []
        shape_list.append(values_in_shape)
        shape_list.append([N, W])  # indices
        shape_list.append(input_shape)

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Return shapes for an elementwise op where one input broadcasts.

        One randomly chosen input has a randomly chosen dimension set to 1
        (all other inputs have that dimension > 1). For ErrorIf.RankMismatch
        that input gets a different rank instead, and for
        ErrorIf.BroadcastShapesMismatch an incompatible dimension.
        """
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        fuzz_idx = testGen.randInt(0, rank)

        for i in range(pl + const):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank (rank + 1), or drop the
                        # extra rank and one original dimension (rank - 1)
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (OHWI), bias (O,)] shapes for CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Return [ifm (NDHWC), filter (ODHWI), bias (O,)] shapes for CONV3D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension()

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (OHWI), bias (O,)] shapes for TRANSPOSE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Return [ifm (NHWC), filter (HWCM), bias (C*M,)] shapes for DEPTHWISE_CONV2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C. The modulus is kept at least 1 so a
        # small maximum shape range (< 4) does not divide by zero.
        filter_m = (
            testGen.makeDimension()
            % max(1, testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def _makeFFTShape(testGen, rank, error_name):
        """Return an NHW input shape with power-of-two H and W for the FFT ops.

        For ErrorIf.KernelNotPowerOfTwo, H and/or W are then incremented so
        they are no longer powers of two.
        """
        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        # (exact integer arithmetic, avoiding float log rounding)
        ifm_shape[1] = 1 << (int(ifm_shape[1]).bit_length() - 1)
        ifm_shape[2] = 1 << (int(ifm_shape[2]).bit_length() - 1)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        return testGen.constrictBatchSize(ifm_shape)

    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        """Return the two (real, imaginary) NHW input shapes for FFT2D.

        For ErrorIf.FFTInputShapeMismatch one H or W dimension of one of the
        inputs is doubled.
        """
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        ifm_shape = TosaTensorGen._makeFFTShape(testGen, rank, error_name)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        """Return the single NHW input shape for RFFT2D."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        return [TosaTensorGen._makeFFTShape(testGen, rank, error_name)]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Return [input, filter (OC, input_shape[1]), bias (OC,)] shapes."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Return [a_shape (N, H, C), b_shape (N, C, W)] for MATMUL."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Return the input shapes for CONCAT.

        Adds a random number of extra inputs (testGen.randInt(0, 4)) on top
        of the operator's declared operands. For
        ErrorIf.ConcatInputRankMismatch every input after the first has one
        dimension removed or appended.
        """
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat, in addition to the declared operands
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Split the concat shape along axis to provide the extra inputs.

        Returns shapeList unchanged for the axis/rank ERROR_IF tests and
        when the axis cannot be split (only dims off the axis are altered
        for ErrorIf.ConcatInputDimMismatch). Otherwise returns a new list:
        the original shape followed by shapes whose axis lengths are
        successive halvings of shapeList[0][axis].
        """
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
624
625
626class TosaTensorValuesGen:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100627 """Tensor Value generators create the random data for each tensor in each test."""
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100628
629 def __init__(self):
630 pass
631
Jeremy Johnson1271c442023-09-05 11:39:26 +0100632 class TVGInfo:
633 """Enhanced tensor values information including data gen dict."""
634
635 def __init__(self, tensorList, dataGenDict):
636 self.tensorList = tensorList
637 self.dataGenDict = dataGenDict
638
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100639 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000640 def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100641 pCount, cCount = op["operands"]
642
643 tens = []
644 tens.extend(
645 testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
646 )
647 tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
648
649 return tens
650
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100651 # Default high value for random numbers
652 TVG_FLOAT_HIGH_VALUE = {
653 DType.FP32: (1 << 128) - (1 << (127 - 23)),
654 DType.FP16: (1 << 16) - (1 << (15 - 10)),
655 DType.BF16: (1 << 128) - (1 << (127 - 7)),
656 }
657
Jeremy Johnson30476252023-11-20 16:15:30 +0000658 # Default lowest normal values for random numbers
659 TVG_FLOAT_LOW_VALUE = {
660 DType.FP32: np.exp2(-126),
661 DType.FP16: np.exp2(-14),
662 DType.BF16: np.exp2(-126),
663 }
664
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100665 @staticmethod
Jeremy Johnson30476252023-11-20 16:15:30 +0000666 def _get_data_range(testGen, dtype, highValueLookup, lowValueLookup=None):
667 # Return a tuple of (low,high) data range values for the given data
668 # type using a combination of per operator table limits, data limits
669 # and user supplied ranges for FP numbers
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000670 if dtype in highValueLookup:
Jeremy Johnson30476252023-11-20 16:15:30 +0000671 type_range = testGen.getDTypeRange(dtype, high_inclusive=True)
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000672 high_val = highValueLookup[dtype]
Jeremy Johnson30476252023-11-20 16:15:30 +0000673 if lowValueLookup is not None and dtype in lowValueLookup:
674 low_val = lowValueLookup[dtype]
675 else:
676 low_val = -high_val
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000677 # Set the values to something that won't produce infinity whilst
Jeremy Johnson30476252023-11-20 16:15:30 +0000678 # respecting the default ranges if more/less than the low/high
679 # values
680 data_range = (
681 max(low_val, type_range[0]),
682 min(high_val, type_range[1]),
683 )
684 if data_range[0] > data_range[1]:
685 # Invalid data range from low to high created due to user
686 # constraints revert to using internal ranges as they are
687 # known to work
688 msg = f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
689 warnings.warn(msg)
690 data_range = (low_val, high_val)
691 return data_range
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000692 return None
693
694 @staticmethod
Jeremy Johnson1271c442023-09-05 11:39:26 +0100695 def tvgLazyGenDefault(
696 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
697 ):
698 # Variable inputs versus constants
699 pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
Jeremy Johnson3eafe662024-01-10 13:13:35 +0000700 if "p_count" in argsDict:
701 # Override for operators like CONCAT
702 pCount = argsDict["p_count"]
703 cCount = argsDict["c_count"]
704 assert pCount + cCount == len(
705 shapeList
706 ), "Placeholders & Constant tensors must match shapes list"
707
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000708 tens_ser_list = []
Jeremy Johnson1271c442023-09-05 11:39:26 +0100709
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100710 if (
711 error_name is not None
712 or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100713 or "data_gen" not in testGen.TOSA_OP_LIST[opName]
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100714 ):
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000715 # Fall back to internal data gen when dealing with unsupported types or ops
716 data_range = argsDict["data_range"] if "data_range" in argsDict else None
717 for idx, info in enumerate(zip(shapeList, dtypeList)):
Jeremy Johnson30476252023-11-20 16:15:30 +0000718 roundMode = False
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000719 shape, dtype = info
Jeremy Johnson30476252023-11-20 16:15:30 +0000720 if "data_range_list" in argsDict:
721 data_range = argsDict["data_range_list"][idx]["range"]
722 roundMode = (
723 "round" in argsDict["data_range_list"][idx]
724 and argsDict["data_range_list"][idx]["round"] is True
725 )
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000726 if data_range is not None and dtype not in (
727 DType.FP16,
728 DType.FP32,
729 DType.BF16,
730 ):
731 # Change from inclusive to exclusive range
732 data_range = (data_range[0], data_range[1] + 1)
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000733 # Ignore lazy data gen option and create data array using any range limits
734 arr = testGen.getRandTensor(shape, dtype, data_range)
Jeremy Johnson30476252023-11-20 16:15:30 +0000735 if roundMode:
736 arr = np.round(arr)
Jeremy Johnson30a41db2023-11-15 11:00:49 +0000737 if idx < pCount:
738 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
739 else:
740 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100741
Jeremy Johnson1271c442023-09-05 11:39:26 +0100742 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
743
744 # Create data generator meta-data
745 dg_type = argsDict["dg_type"]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100746 tens_data = {
747 "version": "0.1",
748 "tensors": {},
749 }
750 dg_tens_meta = tens_data["tensors"]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100751 for idx, shape in enumerate(shapeList):
752
753 tens_meta = {}
754 tens_meta["generator"] = gtu.DataGenType(dg_type).name
755 tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
756 tens_meta["shape"] = [int(i) for i in shape]
757 tens_meta["input_pos"] = idx
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100758 tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()
Jeremy Johnson1271c442023-09-05 11:39:26 +0100759
760 if idx < pCount:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100761 tens_meta["input_type"] = "VARIABLE"
Jeremy Johnson1271c442023-09-05 11:39:26 +0100762 else:
Jeremy Johnsonfc5e34e2023-10-24 14:45:12 +0100763 tens_meta["input_type"] = "CONSTANT"
Jeremy Johnson1271c442023-09-05 11:39:26 +0100764
765 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
766 info = {}
767 # TODO - generate seed for this generator based on test
Jeremy Johnsond41feb72023-10-12 16:03:15 +0100768 info["rng_seed"] = 42
Jeremy Johnson30476252023-11-20 16:15:30 +0000769
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000770 data_range = None
Jeremy Johnson30476252023-11-20 16:15:30 +0000771 if "data_range_list" in argsDict:
772 data_range = argsDict["data_range_list"][idx]["range"]
773 if "round" in argsDict["data_range_list"][idx]:
774 info["round"] = argsDict["data_range_list"][idx]["round"]
775 elif "data_range" in argsDict:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100776 data_range = argsDict["data_range"]
Jeremy Johnsona8420ad2023-12-07 16:35:28 +0000777
778 if data_range is None:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +0100779 data_range = testGen.getDTypeRange(
780 dtypeList[idx], high_inclusive=True
781 )
782 info["range"] = [str(v) for v in data_range]
Jeremy Johnson1271c442023-09-05 11:39:26 +0100783 tens_meta["pseudo_random_info"] = info
784 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
785 info = {}
786 info["s"] = argsDict["s"]
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100787 info["ks"] = int(argsDict["ks"])
788 if "acc_type" in argsDict:
789 # Convert type number into JSON name
790 info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
791 "json"
792 ]
793 if "kernel" in argsDict:
794 info["kernel"] = [int(k) for k in argsDict["kernel"]]
795 if "axis" in argsDict:
796 info["axis"] = int(argsDict["axis"])
Jeremy Johnson1271c442023-09-05 11:39:26 +0100797 tens_meta["dot_product_info"] = info
798 else:
799 # TODO - other data gen type
800 assert False, "TODO: support other data gen types"
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +0100801
802 # Using the finished generate config meta data - generate the data if
803 # needed and assign a tensor name from the serializer
804
805 # Need to generate data when not lazy or for the bias tensor as we need
806 # to work out if the bias data is non-zero for compliance
807 if not testGen.args.lazy_data_gen or (
808 idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
809 ):
810 # Give this tensor a temporary name until we get one from the serializer
811 temp_name = f"placeholder_{idx}"
812 dg_tens_meta[temp_name] = tens_meta
813 # Create data now using the temporary name to access meta details
814 data = testGen.dgl.get_tensor_data(temp_name, tens_data)
815 # Remove the item as we will give it the correct name later
816 del dg_tens_meta[temp_name]
817
818 if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
819 # The KS value used by compliance verification is altered when the
820 # bias data is non-zero
821 if max(abs(data)) > 0.0:
822 argsDict["ksb"] = argsDict["ks"] + 1
823
824 if testGen.args.lazy_data_gen:
825 data = None
826
827 if tens_meta["input_type"] == "VARIABLE":
828 tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
829 else:
830 tens = testGen.ser.addConst(shape, dtypeList[idx], data)
831
832 tens_ser_list.append(tens)
833 # Add the meta data to the list using the serializer tensor name
Jeremy Johnson1271c442023-09-05 11:39:26 +0100834 dg_tens_meta[tens.name] = tens_meta
835
Jeremy Johnson1271c442023-09-05 11:39:26 +0100836 return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
837
838 @staticmethod
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000839 def tvgNegate(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100840 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000841 # Integer test
842 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100843 pCount, cCount = op["operands"]
844 assert (
845 pCount == 1 and cCount == 0
846 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100847 # Must create tensors with values within accumulator (int32) negatable
848 # range
849 max_val = (1 << 31) - 1
850 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100851 arr = np.int32(
852 testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
853 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000854 tens_ser_list = []
855 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100856 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
857 )
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000858 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100859 else:
Jeremy Johnson2d70ac42023-11-06 17:46:02 +0000860 # ERROR_IF or floating point test
861 return TosaTensorValuesGen.tvgLazyGenDefault(
862 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100863 )
864
Jeremy Johnson30476252023-11-20 16:15:30 +0000865 # Set the ADD/SUB data range to half the largest value to avoid infinities
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000866 TVG_FLOAT_HIGH_VALUE_ADDSUB = {
867 DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
868 DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
869 DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
870 }
871
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100872 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000873 def tvgAddSub(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100874 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000875 # Make sure the integer operation does not cause value saturation - where
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100876 # the number wraps due to limited number of bits to store the answer
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000877 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100878 pCount, cCount = op["operands"]
879 assert (
880 pCount == 2 and cCount == 0
881 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000882 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100883 add = op["op"] == Op.ADD
884 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
885 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
886 if add:
887 res_arr = np.add(a_arr, b_arr, dtype=np.int64)
888 else:
889 res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
890
891 # Work out the saturation limits
892 max_i32 = (1 << 31) - 1
893 min_i32 = -(1 << 31)
894 max_arr = np.full(shapeList[1], max_i32)
895 min_arr = np.full(shapeList[1], min_i32)
896
897 # Find how much values exceed the maximum/minimums
898 sat_max_arr = np.maximum(res_arr - max_arr, 0)
899 sat_min_arr = np.minimum(res_arr - min_arr, 0)
900
901 if not add:
902 # Swap saturation values and negate values as we need to perform opposite operations
903 sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
904
905 # Create new array of unsaturated values by clipping values as needed
906 b_unsat_arr = b_arr
907 if (sat_max_arr != 0).any():
908 # Clip values that cause saturation
909 b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
910 # Reduce axes in unsaturated tensor to match original tensor
911 for axis, dim in enumerate(b_arr.shape):
912 if dim != b_unsat_arr.shape[axis]:
913 assert (
914 dim == 1
915 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
916 b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
917
918 if (sat_min_arr != 0).any():
919 # Clip values that cause saturation
920 b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
921 # Reduce axes in unsaturated tensor to match original tensor
922 for axis, dim in enumerate(b_arr.shape):
923 if dim != b_unsat_arr.shape[axis]:
924 assert (
925 dim == 1
926 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
927 b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
928
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000929 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100930 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
931 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000932 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100933 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
934 )
935
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000936 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100937 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000938 # ERROR_IF or floating point test
939 data_range = TosaTensorValuesGen._get_data_range(
940 testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
941 )
942 if data_range:
943 argsDict["data_range"] = data_range
944
945 return TosaTensorValuesGen.tvgLazyGenDefault(
946 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100947 )
948
949 @staticmethod
950 def tvgCondIfWhileLoop(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000951 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100952 ):
953 if dtypeList[0] in (
954 DType.INT32,
955 DType.INT16,
956 DType.INT8,
957 ):
958 # Limit input tensors with cond_if_binary or while_loop to stop
959 # saturation of add/sub ops with int32 and keep all logical shift
960 # values between 0 to 31 for int16 or int8
961 pCount, cCount = op["operands"]
962 pRemain = pCount
963 placeholders = []
964 for idx, shape in enumerate(shapeList[:]):
965 if dtypeList[0] == DType.INT32:
966 arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
967 else:
968 arr = np.int32(
969 testGen.rng.integers(low=0, high=32, size=shapeList[idx])
970 )
971 if pRemain > 0:
972 placeholders.append(
973 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
974 )
975 pRemain -= 1
976 else:
977 placeholders.append(
978 testGen.ser.addConst(shape, dtypeList[idx], arr)
979 )
980
981 return placeholders
982 else:
983 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000984 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100985 )
986
987 @staticmethod
988 def tvgArithmeticRightShift(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000989 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100990 ):
991 pCount, cCount = op["operands"]
992 # Force value of operand[1] to be within [0, num_bits]
993 assert (
994 pCount == 2 and cCount == 0
995 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
996
997 placeholders = []
998 for idx, shape in enumerate(shapeList[:]):
999 if idx == 1:
1000 if dtypeList[idx] == DType.INT8:
1001 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
1002 elif dtypeList[idx] == DType.INT16:
1003 arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
1004 elif dtypeList[idx] == DType.INT32:
1005 arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
1006 elif error_name == ErrorIf.WrongInputType:
1007 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
1008 else:
1009 raise Exception("OpArithmeticRightShift: invalid input dtype")
1010 else:
1011 arr = testGen.getRandTensor(shape, dtypeList[idx])
1012 placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
1013
1014 return placeholders
1015
1016 @staticmethod
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00001017 def tvgSelect(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001018 # Set datatype of condition tensor to boolean
1019 dtypeList[0] = DType.BOOL
1020
Jeremy Johnson7b9abce2024-01-10 11:07:29 +00001021 return TosaTensorValuesGen.tvgLazyGenDefault(
1022 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001023 )
1024
1025 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001026 def tvgIntDiv(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001027 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001028 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001029 pCount, cCount = op["operands"]
1030 assert (
1031 pCount == 2 and cCount == 0
1032 ), "Op.INTDIV must have 2 placeholders, 0 consts"
1033
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001034 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001035
1036 # Two invalid cases for Op.INTDIV:
1037 # 1. divisor == 0
1038 # 2. dividend == -(1<<31) and divisor == -1
1039 while True:
1040 dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1041 divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
1042
1043 if (divisor_arr == 0).any():
1044 continue
1045
1046 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
1047 continue
1048
1049 break
1050
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001051 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001052 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
1053 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001054 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001055 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
1056 )
1057
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001058 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001059 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001060 return TosaTensorValuesGen.tvgLazyGenDefault(
1061 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001062 )
1063
Jeremy Johnson30476252023-11-20 16:15:30 +00001064 # Set the MUL data range to the square root of the largest value
1065 # to avoid infinities
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001066 TVG_FLOAT_HIGH_VALUE_MUL = {
1067 DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1068 DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1069 DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1070 }
1071
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001072 @staticmethod
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001073 def tvgMul(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1074 if error_name is not None or dtypeList[0] in (
1075 DType.FP16,
1076 DType.BF16,
1077 DType.FP32,
1078 ):
1079 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001080 data_range = TosaTensorValuesGen._get_data_range(
1081 testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
1082 )
1083 if data_range:
1084 argsDict["data_range"] = data_range
1085
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001086 return TosaTensorValuesGen.tvgLazyGenDefault(
1087 testGen, opName, dtypeList, shapeList, argsDict, error_name
1088 )
1089 else:
1090 # Integer test
1091 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001092 pCount, cCount = op["operands"]
1093 assert (
1094 pCount == 2 and cCount == 0
1095 ), "Op.MUL must have 2 placeholders, 0 consts"
1096
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001097 tens_ser_list = []
1098
1099 # Make sure multiply result in int32 range
1100 shift = argsDict["shift"]
1101 if dtypeList[0] == DType.INT8:
1102 num_bits = 8
1103 elif dtypeList[0] == DType.INT16:
1104 num_bits = 16
1105 elif dtypeList[0] == DType.INT32:
1106 num_bits = 32
1107 elif error_name == ErrorIf.WrongInputType:
1108 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001109 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001110 raise Exception("OpMul: invalid input dtype")
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001111
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001112 for idx, shape in enumerate(shapeList[:]):
1113 low = -(2 ** (num_bits - 1))
1114 high = (2 ** (num_bits - 1)) - 1
1115
1116 a_arr = np.int32(
1117 testGen.rng.integers(low=low, high=high, size=shapeList[0])
1118 )
1119 b_arr = np.int32(
1120 testGen.rng.integers(low=low, high=high, size=shapeList[1])
1121 )
1122
1123 i = 0
1124 while True:
1125
1126 a_arr_64 = a_arr.astype(np.int64)
1127 b_arr_64 = b_arr.astype(np.int64)
1128
1129 if shift > 0:
1130 rounding = 1 << (shift - 1)
1131 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001132 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001133 result_arr = a_arr_64 * b_arr_64
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001134
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001135 if (result_arr > -(2**31)).all() and (
1136 result_arr <= ((2**31) - 1)
1137 ).all():
1138 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001139
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001140 i = i + 1
1141 a_arr = a_arr // 2
1142 b_arr = b_arr // 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001143
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001144 tens_ser_list.append(
1145 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001146 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001147 tens_ser_list.append(
1148 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1149 )
1150
1151 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001152
1153 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001154 def tvgConcat(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001155 count = len(shapeList) - testGen.args.num_const_inputs_concat
1156 if count < 1:
1157 count = 1
1158 if testGen.args.num_const_inputs_concat == 0:
1159 count = len(shapeList)
1160
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001161 shapeList = TosaTensorGen.tgConcatConstInput(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001162 testGen, shapeList, argsDict["axis"], error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001163 )
1164
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001165 # Override default pCount/cCount for operator
1166 argsDict["p_count"] = count
1167 argsDict["c_count"] = len(shapeList) - count
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001168
Jeremy Johnson3eafe662024-01-10 13:13:35 +00001169 return TosaTensorValuesGen.tvgLazyGenDefault(
1170 testGen, opName, dtypeList, shapeList, argsDict, error_name
1171 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001172
1173 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001174 def tvgLogicalShift(
1175 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1176 ):
1177 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001178 pCount, cCount = op["operands"]
1179 assert (
1180 pCount == 2 and cCount == 0
1181 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
1182 values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1183 shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001184 tens_ser_list = []
1185 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001186 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1187 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001188 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001189 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1190 )
1191
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001192 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001193
1194 @staticmethod
Jeremy Johnsona0150012023-11-15 15:52:06 +00001195 def tvgEqual(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1196 if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
1197 # Integer
1198 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001199 pCount, cCount = op["operands"]
1200 assert (
1201 pCount == 2 and cCount == 0
1202 ), "Op.EQUAL must have 2 placeholders, 0 consts"
Jeremy Johnsona0150012023-11-15 15:52:06 +00001203
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001204 a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1205 b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
Jeremy Johnsona0150012023-11-15 15:52:06 +00001206
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001207 # Using random numbers means that it will be very unlikely that
1208 # there are any matching (equal) values, therefore force that
1209 # there are twice the number of matching values as the tensor rank
1210 for num in range(0, len(shapeList[0]) * 2):
1211 a_index = []
1212 b_index = []
1213 # Choose an index in each axis for the whole shape
1214 for axis in range(0, len(shapeList[0])):
1215 # Index can be up to the largest dimension in both shapes
1216 index = np.int32(
1217 testGen.rng.integers(
1218 0, max(shapeList[0][axis], shapeList[1][axis])
1219 )
1220 )
1221 # Reduce the index down to a shape's dim for broadcasting
1222 a_index.append(min(shapeList[0][axis] - 1, index))
1223 b_index.append(min(shapeList[1][axis] - 1, index))
1224
1225 a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
1226
Jeremy Johnsona0150012023-11-15 15:52:06 +00001227 tens_ser_list = []
1228 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001229 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
1230 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001231 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001232 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1233 )
Jeremy Johnsona0150012023-11-15 15:52:06 +00001234 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001235 else:
Jeremy Johnsona0150012023-11-15 15:52:06 +00001236 # ERROR_IF or floating point test
1237 return TosaTensorValuesGen.tvgLazyGenDefault(
1238 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001239 )
1240
1241 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001242 def tvgReduceSum(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson30476252023-11-20 16:15:30 +00001243 dtype = dtypeList[0]
1244 if dtype == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001245 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001246 pCount, cCount = op["operands"]
1247 assert (
1248 pCount == 1 and cCount == 0
1249 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
1250 # Limit values so that the sum cannot exceed the range of an int32 during
1251 # summation of any axis
1252 range_val = int((1 << 31) / max(shapeList[0]))
1253 values_arr = np.int32(
1254 testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
1255 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001256 tens_ser_list = []
1257 tens_ser_list.append(
Jeremy Johnson30476252023-11-20 16:15:30 +00001258 testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001259 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001260 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001261 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001262 # ERROR_IF or dot product floating point test
Jeremy Johnson30476252023-11-20 16:15:30 +00001263 if (
1264 error_name is None
1265 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
1266 ):
1267 # Limit ranges for (non error & non compliance) tests by using
1268 # values that can be summed on any axis to not hit infinity
1269 highval_lookup = {
1270 dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
1271 / max(shapeList[0])
1272 }
1273 data_range = TosaTensorValuesGen._get_data_range(
1274 testGen, dtype, highval_lookup
1275 )
1276 assert data_range is not None
1277 argsDict["data_range"] = data_range
1278
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001279 return TosaTensorValuesGen.tvgLazyGenDefault(
1280 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001281 )
1282
Jeremy Johnsonbd801962024-01-03 17:07:44 +00001283 @staticmethod
1284 def tvgReduceProduct(
1285 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1286 ):
1287 dtype = dtypeList[0]
1288 if error_name is None:
1289 # Limit ranges for (non error) tests by using
1290 # values that can be multiplied on any axis to not hit infinity
1291 highval_lookup = {
1292 dtype: math.pow(
1293 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
1294 1 / max(shapeList[0]),
1295 )
1296 }
1297 data_range = TosaTensorValuesGen._get_data_range(
1298 testGen, dtype, highval_lookup
1299 )
1300 assert data_range is not None
1301 argsDict["data_range"] = data_range
1302
1303 return TosaTensorValuesGen.tvgLazyGenDefault(
1304 testGen, opName, dtypeList, shapeList, argsDict, error_name
1305 )
1306
Jeremy Johnson30476252023-11-20 16:15:30 +00001307 # Set the POW exponent high data range
1308 TVG_FLOAT_HIGH_VALUE_POW_EXP = {
1309 DType.FP32: 10.0,
1310 DType.FP16: 10.0,
1311 DType.BF16: 10.0,
1312 }
1313 # POW highest base value (within a safe margin of error) that can be raised
1314 # to +ve exponent that doesn't become Infinity
1315 TVG_FLOAT_HIGH_VALUE_POW_BASE = {
1316 DType.FP32: math.floor(
1317 math.pow(
1318 TVG_FLOAT_HIGH_VALUE[DType.FP32],
1319 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1320 )
1321 ),
1322 DType.FP16: math.floor(
1323 math.pow(
1324 TVG_FLOAT_HIGH_VALUE[DType.FP16],
1325 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1326 )
1327 ),
1328 DType.BF16: math.floor(
1329 math.pow(
1330 TVG_FLOAT_HIGH_VALUE[DType.BF16],
1331 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1332 )
1333 ),
1334 }
1335 # POW lowest base value (within a safe margin of error) that can be raised
1336 # to -ve exponent that doesn't become Infinity
1337 TVG_FLOAT_LOW_VALUE_POW_BASE = {
1338 DType.FP32: math.ceil(
1339 math.pow(
1340 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
1341 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
1342 )
1343 * 1000
1344 )
1345 / 1000,
1346 DType.FP16: math.ceil(
1347 math.pow(
1348 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
1349 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
1350 )
1351 * 1000
1352 )
1353 / 1000,
1354 DType.BF16: math.ceil(
1355 math.pow(
1356 1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
1357 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
1358 )
1359 * 1000
1360 )
1361 / 1000,
1362 }
1363
1364 @staticmethod
1365 def tvgPow(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1366 if error_name is not None:
1367 return TosaTensorValuesGen.tvgLazyGenDefault(
1368 testGen, opName, dtypeList, shapeList, argsDict, error_name
1369 )
1370 dtype = dtypeList[0]
1371 # Different ranges for POW
1372 test_set = argsDict["s"]
1373 if test_set == 0:
1374 # Positive base with fractional exponent
1375 base_range = TosaTensorValuesGen._get_data_range(
1376 testGen,
1377 dtype,
1378 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1379 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1380 )
1381 exp_range = TosaTensorValuesGen._get_data_range(
1382 testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
1383 )
1384 exp_round = False
1385 else:
1386 # Integer exponent
1387 exp_range = TosaTensorValuesGen._get_data_range(
1388 testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
1389 )
1390 exp_round = True
1391 if test_set == 1:
1392 # Positive base
1393 base_range = TosaTensorValuesGen._get_data_range(
1394 testGen,
1395 dtype,
1396 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
1397 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
1398 )
1399 else:
1400 assert test_set == 2
1401 # Negative base
1402 # Supply new look up tables with negative values
1403 base_range = TosaTensorValuesGen._get_data_range(
1404 testGen,
1405 dtype,
1406 {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
1407 {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
1408 )
1409
1410 data_range_list = (
1411 {
1412 "range": base_range,
1413 },
1414 {
1415 "range": exp_range,
1416 "round": exp_round,
1417 },
1418 )
1419 argsDict["data_range_list"] = data_range_list
1420 return TosaTensorValuesGen.tvgLazyGenDefault(
1421 testGen, opName, dtypeList, shapeList, argsDict, error_name
1422 )
1423
1424 @staticmethod
1425 def tvgLogRsqrt(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1426 # LOG & RSQRT data range from lowest expressible positive number to
1427 # largest to avoid NaNs
1428 data_range = TosaTensorValuesGen._get_data_range(
1429 testGen,
1430 dtypeList[0],
1431 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
1432 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
1433 )
1434 if data_range:
1435 argsDict["data_range"] = data_range
1436
1437 return TosaTensorValuesGen.tvgLazyGenDefault(
1438 testGen, opName, dtypeList, shapeList, argsDict, error_name
1439 )
1440
1441 # Set the EXP data range to the log of the largest to smallest values
1442 # to avoid infinities or making the result zero
1443 TVG_FLOAT_HIGH_VALUE_EXP = {
1444 DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
1445 DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
1446 DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
1447 }
1448 TVG_FLOAT_LOW_VALUE_EXP = {
1449 DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
1450 DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
1451 DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
1452 }
1453
1454 @staticmethod
1455 def tvgExp(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1456 data_range = TosaTensorValuesGen._get_data_range(
1457 testGen,
1458 dtypeList[0],
1459 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
1460 TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
1461 )
1462 if data_range:
1463 argsDict["data_range"] = data_range
1464
1465 return TosaTensorValuesGen.tvgLazyGenDefault(
1466 testGen, opName, dtypeList, shapeList, argsDict, error_name
1467 )
1468
1469 @staticmethod
1470 def tvgFullyConnected(
1471 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1472 ):
1473 dtype = dtypeList[0]
1474 if (
1475 error_name is None
1476 and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
Jeremy Johnson718f3472023-11-30 14:18:19 +00001477 and dtype in (DType.BF16,)
Jeremy Johnson30476252023-11-20 16:15:30 +00001478 ):
Jeremy Johnson718f3472023-11-30 14:18:19 +00001479 # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
Jeremy Johnson30476252023-11-20 16:15:30 +00001480 # Limit ranges for (non error & non compliance) FP tests by using
1481 # values that can be multiplied on any axis to not hit infinity/NaN
1482 IC = shapeList[0][1]
1483 highval_lookup = {
1484 dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
1485 }
1486 data_range = TosaTensorValuesGen._get_data_range(
1487 testGen, dtype, highval_lookup
1488 )
1489 assert data_range is not None
1490 argsDict["data_range"] = data_range
1491
1492 return TosaTensorValuesGen.tvgLazyGenDefault(
1493 testGen, opName, dtypeList, shapeList, argsDict, error_name
1494 )
1495
Jeremy Johnson708da822023-11-15 16:25:45 +00001496 @staticmethod
1497 def tvgCast(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1498 in_dtype = dtypeList[0]
1499 out_dtype = argsDict["out_type"]
1500 # Create look up to limit input tensor to output type maximums to avoid
1501 # FP infinities and saturation of integers
1502 out_range = testGen.getDTypeRange(out_dtype, high_inclusive=True)
1503 highval_lookup = {in_dtype: out_range[1]}
1504 data_range = TosaTensorValuesGen._get_data_range(
1505 testGen,
1506 in_dtype,
1507 highval_lookup,
1508 )
1509
1510 assert data_range is not None
1511 argsDict["data_range"] = data_range
1512
1513 return TosaTensorValuesGen.tvgLazyGenDefault(
1514 testGen, opName, dtypeList, shapeList, argsDict, error_name
1515 )
1516
Jeremy Johnsona8420ad2023-12-07 16:35:28 +00001517 @staticmethod
1518 def tvgGather(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1519 K = shapeList[0][1]
1520
1521 # Fix the type of the indices tensor
1522 dtypeList[1] = DType.INT32
1523
1524 dtype = dtypeList[0]
1525 if not gtu.dtypeIsSupportedByCompliance(dtype):
1526 # Test unsupported by data generator
1527 op = testGen.TOSA_OP_LIST[opName]
1528 pCount, cCount = op["operands"]
1529 assert (
1530 pCount == 2 and cCount == 0
1531 ), "Op.GATHER must have 2 placeholders, 0 consts"
1532
1533 tens_ser_list = []
1534 for idx, shape in enumerate(shapeList):
1535 dtype = dtypeList[idx]
1536 if idx != 1:
1537 arr = testGen.getRandTensor(shape, dtype)
1538 tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
1539 else:
1540 # Limit data range of indices tensor upto K (exclusive)
1541 arr = testGen.getRandTensor(shape, dtype, (0, K))
1542 # To match old functionality - create indices as CONST
1543 tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
1544
1545 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1546
1547 else:
1548 # ERROR_IF or floating point test
1549 # Use inclusive values upto index K for indices tensor
1550 data_range_list = (
1551 {"range": None},
1552 {"range": (0, K - 1)},
1553 )
1554 argsDict["data_range_list"] = data_range_list
1555
1556 return TosaTensorValuesGen.tvgLazyGenDefault(
1557 testGen, opName, dtypeList, shapeList, argsDict, error_name
1558 )
1559
@staticmethod
def tvgScatter(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
    """Create the tensor values for SCATTER.

    The indices tensor (operand 1) is forced to INT32 and its values stay
    below dimension K of the values_in tensor (operand 0). Types supported
    by the data generator are delegated to
    ``TosaTensorValuesGen.tvgLazyGenDefault``. Other types get random
    placeholders plus a CONST indices tensor that never repeats an output
    index within a batch.

    Note: ``dtypeList`` is updated in place (index 1 set to INT32).
    """
    K = shapeList[0][1]
    W = shapeList[2][1]

    # The spec requires unique output indices per SCATTER:
    # "It is not permitted to repeat the same output index within a single
    # SCATTER operation and so each output index occurs at most once."
    # so W indices can only be drawn without repeats if W <= K.
    assert K >= W, "Op.SCATTER W must be smaller or equal to K"

    # The indices tensor is always INT32
    dtypeList[1] = DType.INT32

    if gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
        # Type supported by the data generator (this path also serves
        # ERROR_IF tests of such types): restrict the indices to the
        # inclusive range 0..K-1.
        # NOTE(review): uniqueness of indices on this path is left to the
        # data generator - confirm.
        argsDict["data_range_list"] = (
            {"range": None},
            {"range": (0, K - 1)},
            {"range": None},
        )
        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, opName, dtypeList, shapeList, argsDict, error_name
        )

    # Test unsupported by data generator - build the tensors directly
    pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
    assert (
        pCount == 3 and cCount == 0
    ), "Op.SCATTER must have 3 placeholders, 0 consts"

    tens_ser_list = []
    for idx, shape in enumerate(shapeList):
        dtype = dtypeList[idx]
        if idx == 1:
            assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
            # Per batch: shuffle the output indices 0..K-1 and keep the
            # first W, giving an (N, W) array with no repeats per row
            indices_arr = np.array(
                [testGen.rng.permutation(K)[:W] for _ in range(shape[0])],
                dtype=np.int32,
            )
            # To match old functionality - create indices as CONST
            tens_ser_list.append(testGen.ser.addConst(shape, dtype, indices_arr))
        else:
            values = testGen.getRandTensor(shape, dtype)
            tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, values))

    return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
1619
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001620
1621class TosaArgGen:
1622 """Argument generators create exhaustive or random lists of attributes for
1623 operators that take attributes or other parameters.
1624
1625 The return value is a list of (descriptive_name, [arglist]) tuples where
1626 the descriptive_name is appended to the test name and the arglist is expanded
1627 as arguments to the operator build function.
1628 """
1629
def __init__(self):
    """Create the (stateless) argument generator helper.

    No state is held; the generators are invoked as static methods.
    """
1632
@staticmethod
def _add_data_generators(testGen, opName, dtype, arg_list, error_name):
    """Add extra tests for each type of data generator for this op.

    Args:
        testGen: the test generator, providing ``TOSA_OP_LIST``,
            ``shapeStr`` and the dot product test limits.
        opName: operator name, used to look up its ``"data_gen"`` types.
        dtype: input data type; selects the ``"fp"`` or ``"int"`` list.
        arg_list: list of ``(arg_str, args_dict)`` tuples to expand.
        error_name: ERROR_IF test name, or None for a positive test.

    Returns:
        A new list of ``(arg_str, args_dict)`` tuples. Each data generator
        type yields either a single test or one test per test set (names
        suffixed ``_s<N>`` and ``"s"`` set in the dict). Every returned
        ``args_dict`` is a fresh copy with ``"dg_type"`` set, so entries
        for different generator types never share (and overwrite) a dict,
        and the input dicts are left untouched.
    """
    if (
        error_name is None
        and "data_gen" in testGen.TOSA_OP_LIST[opName]
        and gtu.dtypeIsSupportedByCompliance(dtype)
    ):
        data_gen = testGen.TOSA_OP_LIST[opName]["data_gen"]
        if dtype in [DType.FP16, DType.FP32, DType.BF16]:
            dataGenTypesList = data_gen["fp"]
        else:
            dataGenTypesList = data_gen["int"]
    else:
        # Error test or No data generator types listed - assume random
        dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)

    # Expand arg list with other data generator types
    new_arg_list = []
    for dg_type in dataGenTypesList:
        for arg_str, args_dict in arg_list:
            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                # Negative tests only need the default single test
                if error_name is None:
                    num_test_sets = args_dict.get("num_test_sets", 0)
                else:
                    num_test_sets = 0

            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                # Extra tests for each dot product test set
                dot_products = args_dict["dot_products"]
                if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
                    shape_info = (
                        " ({})".format(testGen.shapeStr(args_dict["shape"]))
                        if "shape" in args_dict
                        else ""
                    )
                    print(
                        f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
                    )
                    continue
                # KS and acc_type is required by all dot product generators
                assert "ks" in args_dict
                assert "acc_type" in args_dict

                num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS

            else:
                # Other generator types: default to a single test rather
                # than reusing a value left over from a previous iteration
                num_test_sets = 0

            # Copy per generator type so entries never alias each other
            dg_args_dict = args_dict.copy()
            dg_args_dict["dg_type"] = dg_type

            if num_test_sets > 0:
                for s in range(0, num_test_sets):
                    new_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
                    new_args_dict = dg_args_dict.copy()
                    new_args_dict["s"] = s
                    new_arg_list.append((new_arg_str, new_args_dict))
            else:
                # Default is a single test
                new_arg_list.append((arg_str, dg_args_dict))

    return new_arg_list
1694
@staticmethod
def agNone(testGen, opName, shapeList, dtype, error_name=None):
    """Argument generator for operators without non-tensor arguments.

    Starts from a single unnamed entry with an empty argument dict and lets
    ``TosaArgGen._add_data_generators`` expand it per data generator type.
    Returns a list of ``(arg_str, args_dict)`` tuples.
    """
    no_args = [("", {})]
    return TosaArgGen._add_data_generators(
        testGen, opName, dtype, no_args, error_name
    )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001708
@staticmethod
def agPow(testGen, opName, shapeList, dtype, error_name=None):
    """Argument generator for POW.

    POW asks for 3 test sets so random inputs can be covered without
    creating NaNs or Infs. Returns a list of ``(arg_str, args_dict)``
    tuples expanded by ``TosaArgGen._add_data_generators``.
    """
    pow_args = [("", {"num_test_sets": 3})]
    return TosaArgGen._add_data_generators(
        testGen, opName, dtype, pow_args, error_name
    )
1722
@staticmethod
def agAxis(testGen, opName, shapeList, dtype, error_name=None):
    """Build the axis argument for operators that take a single axis.

    One test is created per dimension of the first input. The AxisSmallerZero
    and AxisLargerRank negative tests use a single out-of-range axis instead.
    For REDUCE_SUM the compliance information is also recorded:
    ``ks`` (size of the reduced axis), ``dot_products`` (number of output
    elements, one dot product each), ``shape`` and ``acc_type`` (FP32 for
    BF16, otherwise the input type).

    Returns a list of ``(arg_str, args_dict)`` tuples, expanded by
    ``TosaArgGen._add_data_generators``.
    """
    shape = shapeList[0]

    if error_name == ErrorIf.AxisSmallerZero:
        # Set too small axis
        axes = [testGen.rng.integers(-5, 0)]
    elif error_name == ErrorIf.AxisLargerRank:
        # Set too large axis
        axes = [testGen.rng.integers(len(shape) + 1, len(shape) + 10)]
    else:
        # Create tests for each dimension
        axes = range(0, len(shape))

    opid = testGen.TOSA_OP_LIST[opName]["op"]

    arg_list = []
    for a in axes:
        args_dict = {"axis": int(a)}
        if opid == Op.REDUCE_SUM:
            axis_valid = 0 <= a < len(shape)
            # Each output element is one dot product of length KS (the
            # reduced axis), so count the elements of the output shape -
            # the input shape with the reduced axis set to 1 - rather than
            # the whole input (which overcounts by shape[axis])
            output_shape = list(shape)
            if axis_valid:
                output_shape[a] = 1
            args_dict["dot_products"] = gtu.product(output_shape)
            args_dict["shape"] = shape
            args_dict["ks"] = int(shape[a]) if axis_valid else 1
            args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32

        arg_list.append(("axis{}".format(a), args_dict))

    arg_list = TosaArgGen._add_data_generators(
        testGen,
        opName,
        dtype,
        arg_list,
        error_name,
    )
    # Return list of tuples: (arg_str, args_dict)
    return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001760
@staticmethod
def _calculate_sparsity(num_tests, sparsity_factor):
    """Return the step used to sparsely sample a large set of test combinations.

    The step is ``num_tests // sparsity_factor + 1``. A step below 13 means
    there are only a few tests, so all of them are selected (step 1).
    Otherwise the step is increased until it is not a multiple of 2, 3 or
    5, so the sampled tests cover a variety of parameter combinations.
    """
    step = num_tests // sparsity_factor + 1
    if step < 13:
        return 1
    while any(step % prime == 0 for prime in (2, 3, 5)):
        step += 1
    return step
1772
@staticmethod
def agConv(testGen, opName, shapeList, dtypes, error_name=None):
    """Build stride/padding/dilation arguments for CONV2D, CONV3D and
    DEPTHWISE_CONV2D.

    Only combinations whose padded input exceeds the dilated kernel (a
    positive output size) are kept. The output size must also divide
    exactly, except for the ``ErrorIf.ConvOutputShapeNonInteger`` negative
    test, which keeps only non-exact combinations. When
    ``testGen.args.level8k`` is set, only the 8k level boundary values are
    generated and no negative tests are produced.

    Returns a list of ``(arg_str, args_dict)`` tuples, expanded per data
    generator type by ``TosaArgGen._add_data_generators``. Each
    ``args_dict`` holds the attributes (``stride``, ``pad``, ``dilation``,
    ``kernel``, ``acc_type``) and the compliance information (``ks``,
    ``dot_products``, ``shape``).
    """
    arg_list = []

    if testGen.args.level8k and error_name is not None:
        # Don't produce negative large tests
        return arg_list

    # Shape: Batches, (Depth), Height, Width, Channels
    ifm_shape = shapeList[0]
    # Shape: (OFM channels), (KD), KH, KW, IFM channels
    filter_shape = shapeList[1]

    accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

    # Op type checks
    conv3d = opName.startswith("conv3d")
    depthwise = opName.startswith("depthwise")

    # Check the rank
    rank = 5 if conv3d else 4
    if error_name != ErrorIf.WrongRank:
        assert len(ifm_shape) == rank
        assert len(filter_shape) == rank

    # kernel rank omits batch and channels
    k_rank = rank - 2
    # The kernel dims start at index 0 for depthwise filters, else index 1
    k_pos = 0 if depthwise else 1
    k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
    # compliance size - KS
    k_size = gtu.product(k_shape)
    if not depthwise:
        k_size *= ifm_shape[-1]

    if not testGen.args.level8k:
        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        if error_name == ErrorIf.PadSmallerZero:
            p_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            p_vals = list(range(0, testGen.args.max_conv_padding + 1))
        paddings = set(itertools.product(*([p_vals] * k_rank * 2)))

        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            # Stride must be greater than 1 to force non-integer error
            startStride = 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
            s_vals = list(range(startStride, testGen.args.max_conv_stride + 1))
        strides = set(itertools.product(*([s_vals] * k_rank)))

        if error_name == ErrorIf.DilationSmallerOne:
            d_vals = [testGen.rng.choice(range(-5, 1))]
        else:
            d_vals = list(range(1, testGen.args.max_conv_dilation + 1))
        dilations = set(itertools.product(*([d_vals] * k_rank)))

        if not error_name and testGen.args.oversize:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    itertools.product(*([[0, bigPadding]] * (k_rank * 2)))
                )
            bigStride = 8
            strides.update(itertools.product(*([[1, bigStride]] * k_rank)))
            bigDilation = 7
            dilations.update(itertools.product(*([[1, bigDilation]] * k_rank)))

        # No limit on the output dimensions outside of the 8k level tests
        max_dim_size = None

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 120
        sparsity = TosaArgGen._calculate_sparsity(
            len(paddings) * len(strides) * len(dilations), sparsity_factor
        )
    else:
        # Only test 8k levels boundaries
        bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
        bigPadding = bigKernel

        dilation_shape = [1] * k_rank
        pad_shape = [0] * k_rank * 2
        if conv3d:
            # Small stride apart from for big kernel (see below) to keep
            # tensor size/calculation small
            stride_shape = [1] * k_rank
            for idx in range(k_rank):
                pad_offset = idx * 2
                if k_shape[idx] == bigKernel:
                    # Padding shape needs to account for tensor shape
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                    # Big stride to reduce output size
                    stride_shape[idx] = bigKernel
                else:
                    # Account for kernel size
                    pad_shape[pad_offset] = k_shape[idx] - 1
        else:
            # Always have a large stride with extra padding and dilation to keep
            # tensor calculation reasonable
            stride_shape = [bigKernel] * k_rank
            for idx in range(k_rank):
                # Dilation shape must account for kernel size
                dilation_shape[idx] = bigKernel // k_shape[idx]
                # Padding shape needs to accommodate tensor/kernel & dilation
                pad_offset = idx * 2
                pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

        strides = {tuple(stride_shape)}
        dilations = {tuple(dilation_shape)}
        paddings = {tuple(pad_shape)}
        # Create a limit for the output dimensions size
        max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

        # Currently allow all combinations that are reasonable size
        sparsity = 1

    # Same order as nested loops over sorted strides, paddings, dilations
    combinations = itertools.product(
        sorted(strides), sorted(paddings), sorted(dilations)
    )
    for n, (s, p, d) in enumerate(combinations):
        if n % sparsity != 0:
            continue

        # the padded shape must exceed the dilation * kernel to get a positive
        # sized output shape
        if not (
            (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
            and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
            and (
                k_rank < 3
                or (ifm_shape[3] - 1 + p[4] + p[5]) > d[2] * (k_shape[2] - 1)
            )
        ):
            continue

        remainders = []
        outputs = []
        for index in range(k_rank):
            pad_offset = index * 2
            partial = (
                ifm_shape[index + 1]
                - 1
                + p[pad_offset]
                + p[pad_offset + 1]
                - (k_shape[index] - 1) * d[index]
            )
            remainders.append(partial % s[index])
            outputs.append((partial // s[index]) + 1)

        # the parameters must produce integer exact output, unless testing
        # for the non-integer output shape error
        if error_name == ErrorIf.ConvOutputShapeNonInteger:
            wanted = max(remainders) > 0
        else:
            wanted = max(remainders) == 0
        if not wanted:
            continue

        if max_dim_size is not None and max(outputs) >= max_dim_size:
            # Test will consume too much memory - skip it
            continue

        # Compliance - number of dot product calculations
        if depthwise:
            # TODO - add support
            dots = 0
        else:
            dots = gtu.product((ifm_shape[0], *outputs, filter_shape[0]))
        args_dict = {
            "acc_type": accum_dtype,
            "stride": s,
            "pad": p,
            "dilation": d,
            "kernel": k_shape,
            "ks": k_size,
            "dot_products": dots,
            "shape": ifm_shape,
        }

        # Support for larger values than 9 needs different delimiter
        delim = "" if max(s + p + d) <= 9 else "x"
        arg_str = "acc{}_st{}_pad{}_dilat{}".format(
            testGen.typeStr(accum_dtype),
            delim.join([str(x) for x in s]),
            delim.join([str(x) for x in p]),
            delim.join([str(x) for x in d]),
        )
        arg_list.append((arg_str, args_dict))

    arg_list = TosaArgGen._add_data_generators(
        testGen,
        opName,
        dtypes[0],
        arg_list,
        error_name,
    )
    # Return list of tuples: (arg_str, args_dict)
    return arg_list
1996
@staticmethod
def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
    """Build the accumulator-type argument for FULLY_CONNECTED.

    ``dtypes`` is the list/tuple of operand types; the first entry is the
    input type. A single ``acc<type>`` entry is produced, carrying the
    compliance information, and expanded per data generator type by
    ``TosaArgGen._add_data_generators``.
    """
    assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
    input_dtype = dtypes[0]

    if error_name == ErrorIf.WrongOutputType:
        accum_dtype = gtu.get_wrong_output_type(opName, testGen.rng, input_dtype)
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        accum_dtype = DType.INT32
    else:
        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

    # Compliance info: KS = IC from input A (N, IC); the dot product count
    # is N * shapeList[1][0] (presumably the output channels - confirm)
    compliance = {
        "acc_type": accum_dtype,
        "ks": int(shapeList[0][1]),
        "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
        "shape": shapeList[0],
    }
    acc_entry = (f"acc{testGen.typeStr(accum_dtype)}", compliance)

    return TosaArgGen._add_data_generators(
        testGen, opName, input_dtype, [acc_entry], error_name
    )
James Ward8b390432022-08-12 20:48:56 +01002030
@staticmethod
def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
    """Build the accumulator-type arguments for MATMUL.

    One ``acc<type>`` entry is produced per valid accumulator type for the
    input ``dtype`` (FP16 gets both FP16 and FP32). The WrongOutputType and
    WrongInputType negative tests use a deliberately wrong or placeholder
    type. Any other ERROR_IF test with an unexpected input type falls back
    to INT32 instead of leaving the accumulator list unbound.

    Returns a list of ``(arg_str, args_dict)`` tuples with the compliance
    information (``ks``, ``dot_products``, ``shape``, ``acc_type``),
    expanded by ``TosaArgGen._add_data_generators``.
    """
    # Get valid accumulate type(s)
    if dtype == DType.INT8:
        accum_dtypes = [DType.INT32]
    elif dtype == DType.INT16:
        accum_dtypes = [DType.INT48]
    elif dtype == DType.FP16:
        accum_dtypes = [DType.FP16, DType.FP32]
    elif dtype == DType.BF16:
        accum_dtypes = [DType.FP32]
    elif dtype == DType.FP32:
        accum_dtypes = [DType.FP32]
    elif error_name is None:
        assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
    else:
        # Set to something for the ErrorIf case which has an incorrect
        # input data-type; otherwise accum_dtypes would be unbound below
        accum_dtypes = [DType.INT32]

    if error_name == ErrorIf.WrongOutputType:
        # Get incorrect output dtype for ErrorIf case
        accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)]
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output dtype if input type is incorrect
        accum_dtypes = [DType.INT32]

    # Set up compliance info
    args_dict = {
        "ks": int(shapeList[0][2]),  # Set KS = C, from input A (N,H,C)
        # Set dot_products = N*H*W
        "dot_products": gtu.product(
            (shapeList[0][0], shapeList[0][1], shapeList[1][2])
        ),
        "shape": shapeList[0],
    }

    # Create arg tuple of string and dict
    arg_list = []
    for a in accum_dtypes:
        d = args_dict.copy()
        d["acc_type"] = a
        arg_list.append((f"acc{testGen.typeStr(a)}", d))

    arg_list = TosaArgGen._add_data_generators(
        testGen,
        opName,
        dtype,
        arg_list,
        error_name,
    )
    # Return list of tuples: (arg_str, args_dict)
    return arg_list
James Ward8b390432022-08-12 20:48:56 +01002080
@staticmethod
def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
    """Build stride/padding/output-shape arguments for TRANSPOSE_CONV2D.

    Returns a list of ``(arg_str, [accum_dtype, stride, pad, out_shape])``
    tuples. Unlike the other convolution generators the arguments are a
    plain list, and they are not expanded by the data generators. When
    ``testGen.args.level8k`` is set, only the 8k level boundary values are
    generated and no negative tests are produced.
    """
    arg_list = []

    if testGen.args.level8k and error_name is not None:
        # Don't produce negative large tests
        return arg_list

    ifm_shape = shapeList[0]
    filter_shape = shapeList[1]

    accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

    # Must be rank 4
    if error_name != ErrorIf.WrongRank:
        assert len(ifm_shape) == 4
        assert len(filter_shape) == 4

    # Kernel (KH, KW) taken from the filter shape
    k_shape = tuple(filter_shape[1:3])

    if not testGen.args.level8k:
        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
        if error_name == ErrorIf.PadLargerEqualKernel:
            max_filter_size = -max(k_shape[0], k_shape[1])
            p_vals = [
                testGen.rng.choice(range(max_filter_size - 10, max_filter_size))
            ]
        else:
            p_vals = list(
                range(smallest_padding_size, testGen.args.max_conv_padding + 1)
            )
        paddings = set(itertools.product(*([p_vals] * 4)))

        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used to derive output shape, as a divisor
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            s_vals = list(range(1, testGen.args.max_conv_stride + 1))
        strides = set(itertools.product(*([s_vals] * 2)))

        if not error_name and testGen.args.oversize:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    itertools.product(*([[smallest_padding_size, bigPadding]] * 4))
                )
            bigStride = 8
            strides.update(itertools.product(*([[1, bigStride]] * 2)))

        # There are too many parameter combinations, so generate them sparsely,
        # very sparse for negative tests
        sparsity_factor = 2 if error_name else 10
        sparsity = TosaArgGen._calculate_sparsity(
            len(paddings) * len(strides), sparsity_factor
        )
    else:
        # Only test 8k levels boundaries
        bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
        bigPadding = bigKernel

        pad_shape = [0] * (len(k_shape) * 2)
        stride_shape = [1] * len(k_shape)
        # The point at which input dimension combined with the stride will
        # create large output sizes!
        LARGE_SIZE = 2
        for idx in range(len(k_shape)):
            pad_offset = idx * 2
            if k_shape[idx] == bigKernel:
                # Set large stride
                stride_shape[idx] = bigKernel
                # Use negative output padding to reduce shape size
                pad_shape[pad_offset] = -(bigPadding - 1)
                if ifm_shape[idx + 1] > LARGE_SIZE:
                    pad_shape[pad_offset + 1] = -(bigPadding - 1)
            else:
                # The other dimension should be the bigKernel
                alt_idx = 1 - idx
                if (
                    k_shape[alt_idx] == bigKernel
                    and ifm_shape[alt_idx + 1] < LARGE_SIZE
                ):
                    # As the input is small, the large stride won't
                    # affect the output so we can add some padding
                    pad_shape[pad_offset + 1] = bigPadding

        strides = {tuple(stride_shape)}
        paddings = {tuple(pad_shape)}

        # Currently allow all combinations that are reasonable size
        sparsity = 1

    # Same order as nested loops over sorted strides then sorted paddings
    combinations = itertools.product(sorted(strides), sorted(paddings))
    for n, (s, p) in enumerate(combinations):
        if n % sparsity != 0:
            continue

        # Determine the output shape
        oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
        ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
        out_shape = [ifm_shape[0], oh, ow, filter_shape[0]]

        # Support for larger values than 9 needs different delimiter
        delim = "" if max(s + p) <= 9 else "x"
        arg_list.append(
            (
                "acc{}_st{}_pad{}_os{}".format(
                    testGen.typeStr(accum_dtype),
                    delim.join([str(x) for x in s]),
                    delim.join([str(x) for x in p]),
                    "x".join([str(x) for x in out_shape]),
                ),
                [accum_dtype, s, p, out_shape],
            )
        )

    return arg_list
2213
@staticmethod
def agPad(testGen, opName, shapeList, dtype, error_name=None):
    """Build the padding arguments for PAD.

    Every combination of (before, after) padding on every axis of the
    first input is considered. The values are 0..1 normally, and -2..-1
    for the ``ErrorIf.PadSmallerZero`` negative test. Sets of over 1024
    combinations (rank 6 and above, at 4 pairs per axis) are sampled
    sparsely.

    Returns a list of ``(arg_str, args_dict)`` tuples, expanded by
    ``TosaArgGen._add_data_generators``. Each ``args_dict`` holds ``pad``
    (numpy array of per-axis pairs), ``pad_const_int`` and
    ``pad_const_fp``. An empty list is returned for dtypes PAD does not
    handle here.
    """
    rank = len(shapeList[0])

    # Exhaustively test combinations of padding on each side of each dimension
    # - the range of padding values is defined by pad_min and pad_max
    # - for padding >9, the name format needs to be more distinctive
    if error_name == ErrorIf.PadSmallerZero:
        pad_values = list(range(-2, 0))
    else:
        pad_min, pad_max = 0, 1
        pad_values = list(range(pad_min, pad_max + 1))
    axis_pad_values = list(itertools.product(pad_values, pad_values))

    # Only the constant matching the dtype family is randomised
    if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
        pad_const_int = testGen.getRandNumberDType(dtype)
        pad_const_fp = 0
    elif dtype in (DType.FP16, DType.BF16, DType.FP32):
        pad_const_int = 0
        pad_const_fp = testGen.getRandNumberDType(dtype)
    else:
        return []

    list_shape_pad_values = list(itertools.product(*([axis_pad_values] * rank)))

    # If we are producing tests for rank 6 or greater use sparsity
    if len(list_shape_pad_values) > 1024:
        sparsity = TosaArgGen._calculate_sparsity(
            len(list_shape_pad_values), 2 if error_name else 120
        )
    else:
        sparsity = 1

    # Build arg list
    arg_list = []
    for n, axis_pads in enumerate(list_shape_pad_values):
        paddings = list(axis_pads)
        args_valid = True

        if error_name == ErrorIf.PadSmallerZero:
            # Prevent negative output shapes while ensuring still testing for negative padding
            for i in range(rank):
                if paddings[i][0] + paddings[i][1] + shapeList[0][i] < 1:
                    paddings[i] = (0, 0)
                if all(p > -1 for p in paddings[i]):
                    args_valid = False

        if not (args_valid and n % sparsity == 0):
            continue

        name = "pad" + "".join(f"{before}{after}" for before, after in paddings)
        arg_list.append(
            (
                name,
                {
                    "pad": np.array(paddings),
                    "pad_const_int": pad_const_int,
                    "pad_const_fp": pad_const_fp,
                },
            )
        )

    if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
        warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

    # Return list of tuples: (arg_str, args_dict)
    return TosaArgGen._add_data_generators(
        testGen,
        opName,
        dtype,
        arg_list,
        error_name,
    )
2288
2289 @staticmethod
2290 def agPooling(testGen, opName, shapeList, dtype, error_name=None):
2291 arg_list = []
2292
2293 shape = shapeList[0]
2294 if error_name != ErrorIf.WrongRank:
2295 assert len(shape) == 4
2296
Jeremy Johnson0c716862023-04-13 17:18:19 +01002297 test_level8k = testGen.args.level8k and error_name is None
2298
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002299 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01002300 startKernel = 2
2301 startPad = 0
2302 if not test_level8k:
2303 # Generate comprehensive argument lists
2304 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
2305 paddings = {x for x in itertools.product(*([p_vals] * 4))}
2306 # Stride must be greater than 1 to force non-integer error
2307 s_vals = [
2308 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
2309 ]
2310 strides = {x for x in itertools.product(*([s_vals] * 2))}
2311 k_vals = [
2312 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
2313 ]
2314 kernels = {x for x in itertools.product(*([k_vals] * 2))}
2315 max_dim_size = None
2316 else:
2317 # Only test 8k levels
2318 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
2319 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
2320 strides = {(1, bigStride), (bigStride, 4)}
2321 kernels = {(1, bigKernel), (bigKernel, 3)}
2322 paddings = set()
2323 for s in sorted(list(strides)):
2324 for k in sorted(list(kernels)):
2325 padding = []
2326 for idx in range(len(k)):
2327 total_padding = s[idx] - shape[idx + 1] + k[idx]
2328 while total_padding < 0:
2329 # Must meet: shape + padding > kernel
2330 total_padding += s[idx]
2331 if total_padding < k[idx]:
2332 padding.extend([0, total_padding])
2333 else:
2334 # Note this may produce padding >= k[idx] which is not
2335 # allowed - but will be ignored in the creation loop below
2336 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
2337 paddings.add(tuple(padding))
2338 # Create a limit for the output dimensions size
2339 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002340
James Ward8b390432022-08-12 20:48:56 +01002341 if opName == "max_pool2d":
2342 accum_dtypes = [None] # max_pool has no accumulate dtype
2343 elif dtype == DType.INT8 or dtype == DType.INT16:
2344 accum_dtypes = [DType.INT32]
2345 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002346 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002347 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002348 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01002349 elif error_name is None:
2350 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
2351 else:
2352 # Set to something for the ErrorIf case which has
2353 # incorrect input data-type
2354 accum_dtypes = [DType.INT32]
2355
Jeremy Johnson0c716862023-04-13 17:18:19 +01002356 if not test_level8k:
2357 if testGen.args.oversize:
2358 # add some oversize argument values
2359 bigStride = 7
2360 bigKernel = 9
2361 strides.update(
2362 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002363 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01002364 kernels.update(
2365 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
2366 )
2367 if max(shape) < 64:
2368 # padding must be less than the kernel size
2369 bigPadding = bigKernel - 1
2370 paddings.update(
2371 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
2372 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002373
Jeremy Johnson0c716862023-04-13 17:18:19 +01002374 # There are too many parameter combinations, so generate them sparsely,
2375 # very sparse for negative tests
2376 sparsity_factor = 2 if error_name else 500
2377 sparsity = (
2378 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
2379 )
2380 else:
2381 # We have already limited test output combinations for 8k tests
2382 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002383
James Ward8b390432022-08-12 20:48:56 +01002384 arg_str = (
2385 "acc{}_st{}_kern{}_pad{}"
2386 if accum_dtypes[0] is not None
2387 else "st{}_kern{}_pad{}"
2388 )
2389
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002390 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01002391 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002392 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01002393
2394 # Support for larger values than 9 needs different delimiter
2395 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01002396 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01002397 delim.join([str(x) for x in stride]),
2398 delim.join([str(x) for x in kern]),
2399 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01002400 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002401 args_dict = {
2402 "stride": stride,
2403 "pad": pad,
2404 "kernel": kern,
2405 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002406 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002407 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
2408 }
James Ward8b390432022-08-12 20:48:56 +01002409
2410 if accum is not None:
2411 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002412 args_dict["acc_type"] = accum
2413 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01002414
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002415 n = 0
James Ward8b390432022-08-12 20:48:56 +01002416 for a in accum_dtypes:
2417 for s in sorted(list(strides)):
2418 for p in sorted(list(paddings)):
2419 for k in sorted(list(kernels)):
2420 if error_name in [
2421 ErrorIf.StrideSmallerOne,
2422 ErrorIf.KernelSmallerOne,
2423 ErrorIf.PadSmallerZero,
2424 ErrorIf.PadLargerEqualKernel,
2425 ]:
2426 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
2427 testGen, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002428 )
James Ward8b390432022-08-12 20:48:56 +01002429 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002430 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002431 get_arg_list_element(a, sNew, pNew, kNew, shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002432 )
James Ward8b390432022-08-12 20:48:56 +01002433 elif (
2434 n % sparsity == 0
2435 # padding must not exceed the kernel size
2436 and p[0] < k[0]
2437 and p[1] < k[0]
2438 and p[2] < k[1]
2439 and p[3] < k[1]
2440 # the padded shape must exceed the kernel size
2441 and (shape[1] + p[0] + p[1]) > k[0]
2442 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01002443 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002444 partial_h = shape[1] + p[0] + p[1] - k[0]
2445 partial_w = shape[2] + p[2] + p[3] - k[1]
2446 remainder_h = partial_h % s[0]
2447 remainder_w = partial_w % s[1]
2448 output_h = partial_h // s[0] + 1
2449 output_w = partial_w // s[1] + 1
2450 # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
James Ward8b390432022-08-12 20:48:56 +01002451 if (
2452 # the parameters must produce integer exact output
2453 error_name != ErrorIf.PoolingOutputShapeNonInteger
2454 and remainder_h == 0
2455 and remainder_w == 0
2456 ) or (
2457 error_name == ErrorIf.PoolingOutputShapeNonInteger
2458 and (remainder_h != 0 or remainder_w != 0)
2459 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002460 if (
2461 max_dim_size is not None
2462 and max(output_h, output_w) > max_dim_size
2463 ):
2464 # Test will consume too much memory - skip it
2465 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002466 # Dot products = N*OH*OW*C
2467 dp = gtu.product(
2468 (shape[0], output_h, output_w, shape[3])
2469 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002470 arg_list.append(
2471 get_arg_list_element(a, s, p, k, dp, shape)
2472 )
James Ward8b390432022-08-12 20:48:56 +01002473 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002474
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002475 # Now add data generator types
2476 arg_list = TosaArgGen._add_data_generators(
2477 testGen,
2478 opName,
2479 dtype,
2480 arg_list,
2481 error_name,
2482 )
2483
2484 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002485 return arg_list
2486
2487 @staticmethod
def agCast(testGen, opName, shapeList, inDtype, error_name=None):
    """Generate CAST argument lists: one entry per output dtype to test.

    Args:
        testGen: the test generator (provides ``typeStr`` and data generators).
        opName: operator name, forwarded to ``TosaArgGen._add_data_generators``.
        shapeList: input shapes (not used to choose the arguments).
        inDtype: dtype of the CAST input tensor.
        error_name: ERROR_IF test being generated, or None for a positive test.

    Returns:
        List of ``(arg_str, args_dict)`` tuples with ``args_dict["out_type"]``
        holding the output dtype, expanded with data generator variants.

    Raises:
        Exception: if ``inDtype`` has no CAST conversions listed here and the
            test is not a WrongInputType ERROR_IF test.
    """
    # Enumerate the output types here
    if error_name == ErrorIf.WrongOutputType:
        dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
    elif inDtype == DType.INT8:
        dtypeList = [
            DType.BOOL,
            DType.INT16,
            DType.INT32,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
    elif inDtype == DType.INT16:
        dtypeList = [
            DType.BOOL,
            DType.INT8,
            DType.INT32,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
    elif inDtype == DType.INT32:
        dtypeList = [
            DType.BOOL,
            DType.INT8,
            DType.INT16,
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ]
    elif inDtype == DType.BOOL:
        dtypeList = [DType.INT8, DType.INT16, DType.INT32]
    elif inDtype == DType.FP16:
        dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
    elif inDtype == DType.BF16:
        dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
    elif inDtype == DType.FP32:
        dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
    elif error_name == ErrorIf.WrongInputType:
        # Pick some potentially correct output type for incorrect input type
        dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
    else:
        raise Exception("Unexpected input dtype: {}".format(inDtype))

    arg_list = [
        ("out{}".format(testGen.typeStr(outDtype)), {"out_type": outDtype})
        for outDtype in dtypeList
    ]

    # Now add data generator types. These describe how the input tensor is
    # filled, so select them from the input dtype - not from whichever output
    # dtype happened to be enumerated last (which was also unbound when
    # dtypeList was empty).
    arg_list = TosaArgGen._add_data_generators(
        testGen,
        opName,
        inDtype,
        arg_list,
        error_name,
    )

    return arg_list
2550
2551 @staticmethod
def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
    """Enumerate RESCALE argument combinations for input dtype ``inDtype``.

    Candidate output dtypes are UINT8, INT8, INT16, INT32 and UINT16 (in that
    order). For each kept output dtype, every (scale32, double_round,
    per_channel) boolean triple is tried in False-before-True order, nested
    with scale32 outermost.

    Returns:
        List of ``(arg_str, [out_dtype, scale32, double_round, per_channel])``.
    """
    wrong_output = error_name == ErrorIf.WrongOutputType

    def skip_output_dtype(out_dt):
        # True when out_dt must not be paired with inDtype for this test
        if error_name == ErrorIf.OutputZeroPointNotZero and out_dt in (
            DType.UINT8,
            DType.INT8,
            DType.UINT16,
        ):
            return True
        # The UINT16 zero point ErrorIfs are only valid with UINT16 tensors
        if error_name == ErrorIf.U16OutputZeroPointNotValid and out_dt != DType.UINT16:
            return True
        if error_name == ErrorIf.U16InputZeroPointNotValid and inDtype != DType.UINT16:
            return True
        if not wrong_output:
            # UINT8 only converts to/from INT8/INT16
            if inDtype == DType.UINT8 and out_dt not in (DType.INT8, DType.INT16):
                return True
            if out_dt == DType.UINT8 and inDtype not in (DType.INT8, DType.INT16):
                return True
            # UINT16 only converts to/from INT16
            if inDtype == DType.UINT16 and out_dt != DType.INT16:
                return True
            if out_dt == DType.UINT16 and inDtype != DType.INT16:
                return True
            return False
        # WrongOutputType: keep only pairs that really are a wrong output type
        return not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, out_dt)

    def skip_mode(scale32, double_round):
        # True when this scale32/double_round pair does not suit the test
        if error_name == ErrorIf.ScaleTrue and not scale32:
            return True
        if error_name == ErrorIf.ScaleNotTrue and (scale32 or not double_round):
            return True
        # INT48 input is illegal with scale32=True
        if inDtype == DType.INT48 and scale32 and error_name != ErrorIf.ScaleTrue:
            return True
        # Illegal: ERROR_IF(!scale32 && double_round)
        if double_round and not scale32 and error_name != ErrorIf.ScaleNotTrue:
            return True
        return False

    results = []
    for out_dt in (
        DType.UINT8,
        DType.INT8,
        DType.INT16,
        DType.INT32,
        DType.UINT16,
    ):
        if skip_output_dtype(out_dt):
            continue
        for scale32, double_round, per_channel in itertools.product(
            (False, True), repeat=3
        ):
            if skip_mode(scale32, double_round):
                continue
            label = "out{}_sc{}_dr{}_pc{}".format(
                testGen.typeStr(out_dt),
                int(scale32),
                int(double_round),
                int(per_channel),
            )
            results.append((label, [out_dt, scale32, double_round, per_channel]))

    return results
2649
2650 @staticmethod
def agMul(testGen, opName, shapeList, dtype, error_name=None):
    """Generate MUL argument lists.

    For INT32, one entry per ``testGen.args.num_rand_permutations`` is created,
    each with a random ``shift`` from ``testGen.randInt(0, 32)``. Every other
    dtype gets a single entry with ``shift`` 0. Data generator variants are then
    added by ``TosaArgGen._add_data_generators``.

    Returns:
        List of ``(arg_str, args_dict)`` tuples; ``args_dict["shift"]`` is set.
    """
    arg_list = []

    # Compare by value: DType values are flatc-generated constants, and an
    # identity check (``is``) only worked by accident of small-int caching.
    if dtype == DType.INT32:
        for p in range(testGen.args.num_rand_permutations):
            shift = testGen.randInt(0, 32)
            arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
    else:
        arg_list.append(("perm0_shift0", {"shift": 0}))

    arg_list = TosaArgGen._add_data_generators(
        testGen,
        opName,
        dtype,
        arg_list,
        error_name,
    )
    # Return list of tuples: (arg_str, args_dict)
    return arg_list
2671
2672 @staticmethod
def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
    """Return ARITHMETIC_RIGHT_SHIFT arguments: ``round`` True, then False.

    Each entry is ``(arg_str, [round])``.
    """
    return [
        ("roundTrue", [True]),
        ("roundFalse", [False]),
    ]
2680
Luke Hutton57287132023-02-06 14:54:18 +00002681 @staticmethod
def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
    """Return FFT2D arguments: ``inverse`` True, then False.

    Each entry is ``(arg_str, [inverse])``.
    """
    return [
        ("inverseTrue", [True]),
        ("inverseFalse", [False]),
    ]
2689
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002690 # Helper function for reshape. Gets some factors of a larger number.
2691 @staticmethod
def getFactors(val, start=1):
    """Return the divisors of ``val`` between ``start`` and ``int(sqrt(val))``.

    The result is ascending. Only the smaller member of each divisor pair is
    included; the partner of a returned divisor ``d`` is ``val // d``.
    """
    upper = int(np.sqrt(val))
    return [cand for cand in range(start, upper + 1) if val % cand == 0]
2700
2701 @staticmethod
def agReshape(testGen, opName, shapeList, dtype, error_name=None):
    """Generate random RESHAPE target shapes with the same element count.

    For each of ``testGen.args.num_rand_permutations`` permutations, a target
    rank is drawn with ``testGen.randInt(1, TOSA_TENSOR_MAX_RANK + 1)``. The
    permutation is skipped if ``TosaArgGen.getFactors`` of the element count
    yields fewer factors than that rank. Otherwise, up to 100 attempts are made
    to build a new shape. Each attempt picks ``rank - 1`` random factors and
    finishes with the remaining element count. A permutation that finds no
    previously unseen shape is dropped.

    Returns:
        List of ``(arg_str, args_dict)`` tuples with ``args_dict["new_shape"]``,
        expanded by ``TosaArgGen._add_data_generators``.
    """
    element_count = gtu.product(shapeList[0])
    base_factors = TosaArgGen.getFactors(element_count)
    found = []

    # This search is not fast, but the element counts involved are small
    for p in range(testGen.args.num_rand_permutations):
        rank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
        if rank > len(base_factors):
            continue

        # Bound the attempts so shape generation finishes in reasonable time
        for _attempt in range(100):
            shape = []
            left = element_count
            pool = testGen.rng.permutation(base_factors)
            for _dim in range(rank - 1):
                picked = pool[0]
                shape.append(picked)
                left = left // picked
                pool = testGen.rng.permutation(TosaArgGen.getFactors(left))
            shape.append(left)

            if any(entry["new_shape"] == shape for _name, entry in found):
                # Already generated - try again
                continue

            dims_str = "x".join(str(d) for d in shape)
            found.append(
                (
                    "perm{}_rank{}_out{}".format(p, rank, dims_str),
                    {"new_shape": shape},
                )
            )
            # This permutation now has its output shape
            break

    return TosaArgGen._add_data_generators(
        testGen,
        opName,
        dtype,
        found,
        error_name,
    )
2761
2762 @staticmethod
def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
    """Generate TRANSPOSE permutation arguments for input ``shapeList[0]``.

    Candidate permutations depend on ``error_name``:
      * IndexOutsideBounds: permutations of ``rank+1 .. 2*rank`` followed by
        permutations of ``-rank .. -1``.
      * IndexUsedTwice: permutations of ``0 .. rank-1`` after one index has
        been overwritten with a copy of a randomly chosen neighbour.
      * otherwise: all permutations of ``0 .. rank-1``.

    The candidates are shuffled with ``testGen.rng.permutation``. The first
    ``min(len(candidates), num_rand_permutations)`` are returned as
    ``("perm<i>", [perm_list])`` tuples.
    """
    rank = len(shapeList[0])

    if error_name == ErrorIf.IndexOutsideBounds:
        candidates = list(itertools.permutations(range(rank + 1, 2 * rank + 1)))
        candidates += list(itertools.permutations(range(-rank, 0)))
    elif error_name == ErrorIf.IndexUsedTwice:
        # Duplicate one index so that it appears twice
        dims = list(range(rank))
        chosen = testGen.rng.choice(range(len(dims)))
        dims[(chosen + 1) % len(dims)] = dims[chosen]
        candidates = list(itertools.permutations(dims))
    else:
        candidates = list(itertools.permutations(range(rank)))

    # Never ask for more permutations than exist
    count = min(len(candidates), testGen.args.num_rand_permutations)
    shuffled = testGen.rng.permutation(candidates)

    return [("perm{}".format(idx), [shuffled[idx].tolist()]) for idx in range(count)]
2798
2799 @staticmethod
def agSlice(testGen, opName, shapeList, dtype, error_name=None):
    """Generate random SLICE ``[start, size]`` argument pairs.

    For every requested permutation, each dimension larger than 1 gets a random
    start in ``[0, dim)`` and a random size in ``[1, dim - start]``. The slice
    therefore always contains data and may extend to the end of the dimension.
    Other dimensions use start 0 and size 1.

    ``TosaErrorIfArgGen.eiSliceErrorIf`` may then replace start/size with
    invalid values when an ERROR_IF test is requested.

    Returns:
        List of ``("perm<p>", [start, size])`` tuples.
    """
    arg_list = []

    ifm_shape = shapeList[0]

    for p in range(testGen.args.num_rand_permutations):
        start = []
        size = []

        for dim in ifm_shape:
            if dim > 1:
                # randInt excludes its upper bound: begin is in [0, dim)
                begin = testGen.randInt(0, dim)
                start.append(begin)
                # Size from 1 up to the remaining room, so begin + size <= dim.
                # (Previously the size was drawn from [0, dim - begin), which
                # could never reach the last element and produced empty,
                # discarded slices.)
                size.append(testGen.randInt(1, dim - begin + 1))
            else:
                start.append(0)
                size.append(1)

        # If ERROR_IF test required then incorrect start, size will be returned
        start, size = TosaErrorIfArgGen.eiSliceErrorIf(
            testGen, error_name, ifm_shape, start, size
        )
        arg_list.append(("perm{}".format(p), [start, size]))
    return arg_list
2831
2832 @staticmethod
def agTile(testGen, opName, shapeList, dtype, error_name=None):
    """Generate TILE ``multiples`` arguments for input ``shapeList[0]``.

    Small multiples are used so the tiled tensor stays manageable. A dimension
    larger than 1000 gets multiple 1. If any dimension exceeds 1000, the others
    get 2. Otherwise each multiple is ``testGen.randInt(1, 4)``.

    Returns:
        List of ``("perm<p>", [multiples])`` tuples, one per requested
        permutation.
    """
    ifm_shape = shapeList[0]

    def pick_multiple(dim_size):
        if dim_size > 1000:
            return 1
        if max(ifm_shape) > 1000:
            return 2
        return testGen.randInt(1, 4)

    results = []
    for p in range(testGen.args.num_rand_permutations):
        multiples = [pick_multiple(dim_size) for dim_size in ifm_shape]
        results.append(("perm{}".format(p), [multiples]))
    return results
2857
2858 @staticmethod
def agResize(testGen, opName, shapeList, dtype, error_name=None):
    """Generate RESIZE argument lists.

    For each resize mode (NEAREST, BILINEAR) and each output dtype that is
    legal for ``dtype`` in that mode, this tries to produce
    ``testGen.args.num_rand_permutations`` distinct parameter sets. Each entry
    is ``(arg_str, [mode, scale, offset, border, dtype, output_dtype])``, where
    ``scale`` is ``[scale_y_n, scale_y_d, scale_x_n, scale_x_d]``, ``offset``
    is ``[offset_y, offset_x]`` and ``border`` is ``[border_y, border_x]``.

    Output height/width are derived from ``shapeList[0][1]`` and
    ``shapeList[0][2]``. This presumably assumes an NHWC input - TODO confirm.
    """
    arg_list = []
    ifm_shape = shapeList[0]

    def get_aspect_ratio_resize_params():
        # Scale by a common aspect ratio (optionally inverted) with either a
        # letterbox (border_y) or pillarbox (border_x) border
        common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
        aspect_ratio = testGen.rng.choice(common_aspect_ratios)
        invert = testGen.rng.choice((False, True))
        letterbox = testGen.rng.choice((False, True))

        scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
        scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
        scale_y_d = scale_x_d = 1
        offset_x = offset_y = 0

        if letterbox:
            max_border = scale_y_n
            border_y = testGen.randInt(low=0, high=max_border)
            border_x = 0
        else:
            # Pillarboxing
            border_y = 0
            max_border = scale_x_n
            border_x = testGen.randInt(low=0, high=max_border)

        scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        offset = (offset_y, offset_x)
        border = (border_y, border_x)

        return scale, offset, border

    def get_upscale_downscale_params():
        # Either a power-of-two upscale, or a downscale by a small denominator
        # (2..4) chosen from the input height/width; retried until valid
        valid_params = False
        while not valid_params:
            upscale = testGen.rng.choice((False, True))

            # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
            origin_sampling = testGen.rng.choice((False, True))

            if upscale:
                shift = testGen.randInt(low=1, high=4)
                scale_x_d = scale_y_d = 1
                scale_x_n = scale_y_n = (
                    1 << shift if origin_sampling else 2 << shift
                )
                border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
            else:
                scale_x_n = 1
                scale_y_n = 1

                # Return list of valid scale_*_d values (max value 4) given input dim shape
                # (keeps x in 1..4 with ifm_dim % x == 1, so x == 1 is never
                # returned)
                def get_valid_denom(ifm_dim):
                    return [x for x in range(1, 5) if ifm_dim % x == 1]

                # Generate list of valid downscale values and choose one randomly
                valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                if not valid_scale_y_ds and not valid_scale_x_ds:
                    # Bad parameters, skip
                    continue

                if not valid_scale_y_ds:
                    scale_y_d = 1
                else:
                    scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                if not valid_scale_x_ds:
                    scale_x_d = 1
                else:
                    scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                border_x = border_y = 0
                offset_y = testGen.randInt(0, 16 * scale_y_n)
                offset_x = testGen.randInt(0, 16 * scale_x_n)
            valid_params = True

        scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        offset = (offset_y, offset_x)
        border = (border_y, border_x)
        return scale, offset, border

    def get_rand_params():
        # Fully random scale/offset/border, with the scale capped at
        # testGen.TOSA_8K_LEVEL_MAX_SCALE
        def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
            # Grow the denominator until scale_n / scale_d <= max_scale
            scale = scale_n / scale_d
            if scale > max_scale:
                factor = scale / max_scale
                new_scale_d = math.ceil(scale_d * factor)
                assert scale_n / new_scale_d <= max_scale
                scale_d = new_scale_d
            return scale_d

        # Scale
        scale_y_n = testGen.randInt(low=1, high=(1 << 11))
        scale_x_n = testGen.randInt(low=1, high=(1 << 11))

        scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
        scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

        scale_y_d = fix_scale_to_max_scale(
            scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
        )
        scale_x_d = fix_scale_to_max_scale(
            scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
        )

        # Offsets and border within the scale
        offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
        offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
        border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
        border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

        scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
        offset = (offset_y, offset_x)
        border = (border_y, border_x)
        return scale, offset, border

    def get_level_8k_params():
        # One axis scaled by exactly TOSA_8K_LEVEL_MAX_SCALE, the other by
        # 1/2 .. 1/5; offsets/borders are picked from their extreme values
        # Create 64x scale - 64/1 to 2048/32
        # NOTE(review): "/" is true division here, so high is a float
        scale_d = testGen.randInt(
            low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
        )
        scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
        # Create half to fifth scaling
        scale_d_alt = testGen.randInt(low=2, high=6)
        scale_n_alt = 1
        switch = testGen.rng.choice((False, True))
        if switch:
            scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
        else:
            scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)

        offset_y = testGen.rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
        offset_x = testGen.rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
        offset = (offset_y, offset_x)
        border_y = testGen.rng.choice((-16 * scale[0], 0, scale[0] - 1))
        border_x = testGen.rng.choice((-16 * scale[2], 0, scale[2] - 1))
        border = (border_y, border_x)
        return scale, offset, border

    for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
        # Exclude illegal {mode, type} configurations. Pick legal output types
        if mode == ResizeMode.NEAREST and dtype == DType.INT8:
            outputDTypeList = [DType.INT8]
        elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
            outputDTypeList = [DType.INT16]
        elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
            outputDTypeList = [DType.INT32]
        elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
            outputDTypeList = [DType.INT48]
        elif dtype == DType.FP16:
            outputDTypeList = [DType.FP16]
        elif dtype == DType.BF16:
            outputDTypeList = [DType.BF16]
        elif dtype == DType.FP32:
            outputDTypeList = [DType.FP32]
        elif error_name == ErrorIf.WrongInputType:
            # If an incorrect input type is used then we set a 'correct'
            # output type to avoid other errors
            outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
        else:
            continue

        arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

        for outputDType in outputDTypeList:
            # perm counts accepted parameter sets. Most rejections below only
            # advance perm once perm > 0, so generation keeps trying until
            # at least one set is accepted.
            # NOTE(review): a duplicate parameter set never advances perm, so
            # an exhausted parameter space would keep this loop running.
            perm = 0
            while perm < testGen.args.num_rand_permutations:
                # Random choice of type of params we are testing
                if not testGen.args.level8k:
                    _rnd_param_fn = testGen.rng.choice(
                        (
                            get_rand_params,
                            get_upscale_downscale_params,
                            get_aspect_ratio_resize_params,
                        )
                    )
                    scale, offset, border = _rnd_param_fn()
                else:
                    scale, offset, border = get_level_8k_params()

                # Expand params for bounds-checking
                (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                (offset_y, offset_x) = offset
                (border_y, border_x) = border

                # Make sure output dimensions OH and OW are integers
                partial_output_y = (
                    (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                )
                partial_output_x = (
                    (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                )
                if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                    # Look for non-integer test
                    if (
                        partial_output_y % scale_y_d == 0
                        and partial_output_x % scale_x_d == 0
                    ):
                        # Skip this test as it doesn't produce NonInteger output
                        if perm > 0:
                            perm += 1
                        continue
                else:
                    # Alter the scaling factors to make the output integer
                    while partial_output_y % scale_y_d != 0:
                        scale_y_d -= 1
                    while partial_output_x % scale_x_d != 0:
                        scale_x_d -= 1
                    # Make sure we are still within max scaling
                    if (
                        scale_y_n / scale_y_d
                    ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
                        scale_x_n / scale_x_d
                    ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
                        # Skip the test as it is using too large a scaling factor
                        if perm > 0:
                            perm += 1
                        continue

                output_y = partial_output_y // scale_y_d + 1
                output_x = partial_output_x // scale_x_d + 1

                if (
                    output_y >= testGen.args.max_resize_output_dim
                    or output_x >= testGen.args.max_resize_output_dim
                ) and error_name is None:
                    # Skip positive test if output dim will be too high
                    # Avoid high test latency and OOM issues
                    if not testGen.args.level8k or perm > 0:
                        perm += 1
                    continue

                if (
                    output_y <= 0
                    or output_y >= gtu.MAX_RESIZE_DIMENSION
                    or output_x <= 0
                    or output_x >= gtu.MAX_RESIZE_DIMENSION
                ):
                    # Output dimensions out of scope
                    if error_name is not None and perm > 0:
                        # As long as we have one ERROR_IF test, don't worry
                        # about creating all the other permutations
                        perm += 1
                    continue

                if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                    (
                        output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
                        and output_y - scale_y_d < 1
                    )
                    or (
                        output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
                        and output_x - scale_x_d < 1
                    )
                ):
                    # Can't create a negative test with these params as it
                    # will create invalid output size
                    if perm > 0:
                        perm += 1
                    continue

                scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                offset = [offset_y, offset_x]
                border = [border_y, border_x]

                # Common for all data types
                if error_name is not None:
                    (
                        scale,
                        offset,
                        border,
                        outputDTypeNew,
                    ) = TosaErrorIfArgGen.eiResizeErrorIf(
                        testGen,
                        error_name,
                        mode,
                        dtype,
                        shapeList,
                        outputDType,
                        scale,
                        offset,
                        border,
                    )
                else:
                    outputDTypeNew = outputDType

                arg_to_append = (
                    arg_str.format(
                        "N" if mode == ResizeMode.NEAREST else "B",
                        testGen.typeStr(outputDTypeNew),
                        scale[0],
                        scale[1],
                        scale[2],
                        scale[3],
                        offset[0],
                        offset[1],
                        border[0],
                        border[1],
                    ),
                    [
                        mode,
                        scale,
                        offset,
                        border,
                        dtype,
                        outputDTypeNew,
                    ],
                )
                if arg_to_append in arg_list:
                    # Skip already generated test params
                    continue

                # Valid permutation
                perm += 1
                arg_list.append(arg_to_append)
    return arg_list
3177
@staticmethod
def agTable(testGen, opName, shapeList, dtype, error_name=None):
    """Build the argument list for the TABLE operator.

    A single entry is produced whose only argument is a random lookup table.
    For INT8 the table has 256 entries drawn from [-128, 127]. Any other
    dtype is treated as INT16 and gets 513 entries drawn from
    [-32768, 32767]. These are post-processed so that every difference
    between neighbouring entries fits in a signed 16-bit integer.

    Args:
        testGen: test generator; only its ``rng`` attribute is used
            (assumed to be a numpy ``Generator`` -- its ``integers`` method
            is called).
        opName: unused.
        shapeList: unused.
        dtype: input DType; selects the INT8 table, otherwise INT16.
        error_name: unused.

    Returns:
        A one-element list ``[("", [table])]`` where ``table`` is a list of
        Python ints.
    """
    if dtype == DType.INT8:
        raw = testGen.rng.integers(low=-128, high=128, size=[256])
        return [("", [np.int32(raw).tolist()])]

    # INT16
    raw = testGen.rng.integers(low=-32768, high=32768, size=[513])
    entries = np.int32(raw).tolist()
    # Keep every slope within the REQUIRE min/max of a 16-bit int by clamping
    # each entry relative to the (already adjusted) entry before it.
    for pos in range(1, len(entries)):
        prev = entries[pos - 1]
        entries[pos] = min(max(entries[pos], prev - 32768), prev + 32767)
        step = entries[pos] - prev
        assert -32768 <= step <= 32767
    return [("", [entries])]
3207
@staticmethod
def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
    """Build the argument list for COND_IF: one test per condition value.

    Only the boolean condition values are generated here. They are
    converted to tensors in the build function, along with the then and
    else blocks.

    Args:
        testGen: unused.
        opName: unused.
        shapeList: unused.
        dtype: unused.
        error_name: unused.

    Returns:
        ``[("cond0", [False]), ("cond1", [True])]``.
    """
    # Declared static like the sibling generators: the first parameter is
    # the test generator, not an instance of this class.
    return [("cond{}".format(int(cond)), [cond]) for cond in (False, True)]
3218
@staticmethod
def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
    """Build the argument list for WHILE_LOOP iteration counts.

    Covers zero iterations, exactly one, and more than one (4).

    Args:
        testGen: unused.
        opName: unused.
        shapeList: unused.
        dtype: unused.
        error_name: unused.

    Returns:
        ``[("iter0", [0]), ("iter1", [1]), ("iter4", [4])]``.
    """
    # Loop variable is not named ``iter`` so the builtin is not shadowed.
    return [("iter{}".format(count), [count]) for count in (0, 1, 4)]