blob: 1f548518868fbdab3f227ee22fd572b79a5bcd2f [file] [log] [blame]
Luke Hutton261b7b62023-01-10 14:50:31 +00001# Copyright (c) 2021-2023, ARM Limited.
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002# SPDX-License-Identifier: Apache-2.0
3import itertools
4import math
James Ward8b390432022-08-12 20:48:56 +01005import warnings
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01006
Jeremy Johnson1271c442023-09-05 11:39:26 +01007import generator.tosa_utils as gtu
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01008import numpy as np
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01009from generator.tosa_error_if import ErrorIf
10from generator.tosa_error_if import TosaErrorIfArgGen
11from serializer.tosa_serializer import DTypeNames
12from tosa.DType import DType
13from tosa.Op import Op
14from tosa.ResizeMode import ResizeMode
15
16# DTypeNames, DType, Op and ResizeMode are convenience variables to the
17# flatc-generated types that should be enums, but aren't
18
19
20class TosaQuantGen:
21 """QuantizedInfo random generator helper functions.
22
23 Specify with 'qgen': in the operator defintion.
24 """
25
26 def __init__(self):
27 pass
28
29 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +000030 def getZeroPoint(testGen, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010031
32 if dtype == DType.INT8:
Jeremy Johnson00423432022-09-12 17:27:37 +010033 if testGen.args.zeropoint is not None:
34 return min(127, max(-128, testGen.args.zeropoint))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010035 return testGen.randInt(-128, 128)
36 elif dtype == DType.UINT8:
Jeremy Johnson00423432022-09-12 17:27:37 +010037 if testGen.args.zeropoint is not None:
38 return min(255, max(0, testGen.args.zeropoint))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010039 return testGen.randInt(0, 256)
40 elif error_name in [
41 ErrorIf.InputZeroPointNotZero,
42 ErrorIf.WeightZeroPointNotZero,
43 ErrorIf.OutputZeroPointNotZero,
44 ]:
45 zero_point = testGen.randInt(-128, 128)
46 if zero_point == 0:
47 zero_point = 1
48 return zero_point
49 return 0
50
51 @staticmethod
52 def qgUnary(testGen, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010053 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000054 qinfo = [
55 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
56 TosaQuantGen.getZeroPoint(testGen, dtype),
57 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010058 elif error_name == ErrorIf.OutputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000059 qinfo = [
60 TosaQuantGen.getZeroPoint(testGen, dtype),
61 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
62 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010063 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000064 qinfo = [
65 TosaQuantGen.getZeroPoint(testGen, dtype),
66 TosaQuantGen.getZeroPoint(testGen, dtype),
67 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010068 return qinfo
69
70 @staticmethod
71 def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010072 if isinstance(dtype_or_dtypeList, list):
73 # a list of [input, weights, accumulator] dtypes
74 dtypeList = dtype_or_dtypeList
75 else:
76 # an int, [input, weights, accumulator] dtypes are the same
77 dtypeList = [dtype_or_dtypeList] * 3
78
79 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000080 qinfo = [
81 TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
82 TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
83 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010084 elif error_name == ErrorIf.WeightZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000085 qinfo = [
86 TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
87 TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
88 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010089 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000090 qinfo = [
91 TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
92 TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
93 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010094 return qinfo
95
96 @staticmethod
97 def qgMatmul(testGen, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010098 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000099 qinfo = [
100 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
101 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
102 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100103 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000104 qinfo = [
105 TosaQuantGen.getZeroPoint(testGen, dtype),
106 TosaQuantGen.getZeroPoint(testGen, dtype),
107 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100108 return qinfo
109
110 @staticmethod
111 def computeMultiplierAndShift(scaleFp, scale32):
112 # Derived from computeMultiplierAndShiftTosaScale32
113 # Provide a floating-point scaling factor and the scale32 parameter
114 # to compute the multiplier and shift
115
116 if scale32:
117 scaleBits = 31
118 else:
119 scaleBits = 15
120
121 m, shift = math.frexp(scaleFp)
122
123 if scaleFp < 0.0:
124 m = -m
125
126 multiplier = round(m * (1 << scaleBits))
127 assert multiplier <= (1 << scaleBits)
128
129 if multiplier == (1 << scaleBits):
130 multiplier = multiplier // 2
131 shift = shift + 1
132
133 shift = (-shift) + scaleBits
134 # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
135 # scaleFp, scaleBits, m, multiplier, shift))
136
137 # Adjust multiplier such that shift is in allowed value range.
138 if shift == 0:
139 multiplier = multiplier // 4
140 shift = shift + 2
141 elif shift == 1:
142 multiplier = multiplier // 2
143 shift = shift + 1
144 elif shift == 63:
145 multiplier = multiplier * 2
146 shift = shift - 1
147
148 assert multiplier <= (1 << scaleBits)
149 assert shift >= 2 and shift <= 62
150
151 return multiplier, shift
152
153
class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        # Stateless - all generators are static methods.
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        """Give every operand the same randomly chosen shape of the given rank."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            # NOTE(review): the second operand (i == 1) keeps the original rank;
            # subsequent shapes are regenerated with a different rank.
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        """Shape generator for NHWC operators - all operands share one shape."""
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        # (MaxDimExceeded needs the over-sized dimension kept)
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        """Shape generator for gather/scatter style ops.

        Produces a values tensor (N, K, C) and a second tensor (N, W, C)
        sharing the batch and channel dimensions.
        """
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = min(values_in_shape[0], testGen.args.max_batch_size)

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        """Create shapes where one randomly chosen input is broadcastable.

        One input gets dimension 1 at a random (fuzz) index; error variants
        instead create rank or broadcast-shape mismatches.
        """
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        fuzz_idx = testGen.randInt(0, rank)

        for i in range(pl + const):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    # Make the dimension incompatible (neither equal nor 1)
                    shape_bcast[fuzz_idx] += 2
                else:
                    # Valid broadcast: set the fuzz dimension to 1
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        """Shapes for CONV2D: NHWC input, OHWI filter, OC bias."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        """Shapes for CONV3D: NDHWC input, ODHWI filter, OC bias."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension()

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        """Shapes for TRANSPOSE_CONV2D: NHWC input, OHWI filter, OC bias."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        """Shapes for DEPTHWISE_CONV2D: NHWC input, HWCM filter, M*C bias."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, HW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        """Shapes for FFT2D: two NHW inputs with power-of-two H and W."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1 (1 + 1 = 2 is a power of two)
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        """Shapes for RFFT2D: one NHW input with power-of-two H and W."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        """Shapes for FULLY_CONNECTED: (N, IC) input, (OC, IC) filter, OC bias."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        """Shapes for MATMUL: A is (N, H, C) and B is (N, C, W)."""
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        """Shapes for CONCAT: base operands plus up to 3 extra inputs."""
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                # Randomly either drop a dimension or append an extra one so
                # this input's rank differs from the first input's
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        """Split the first concat shape along axis into multiple const inputs.

        Keeps overall concatenated size the same while avoiding many large
        tensors; some ERROR_IF variants return the list unmodified or with
        deliberately mismatched dimensions.
        """
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            # Last iteration - also append the remaining length shape
            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
604
605
606class TosaTensorValuesGen:
Jeremy Johnson1271c442023-09-05 11:39:26 +0100607 """Tensor Value generators create the random data for each tensor in each test."""
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100608
    def __init__(self):
        # No instance state - all value generators are static methods.
        pass
611
Jeremy Johnson1271c442023-09-05 11:39:26 +0100612 class TVGInfo:
613 """Enhanced tensor values information including data gen dict."""
614
615 def __init__(self, tensorList, dataGenDict):
616 self.tensorList = tensorList
617 self.dataGenDict = dataGenDict
618
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100619 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000620 def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100621 pCount, cCount = op["operands"]
622
623 tens = []
624 tens.extend(
625 testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
626 )
627 tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
628
629 return tens
630
    # Default high value for random numbers
    # Each entry is the largest finite value of the format, computed as
    # (2 - 2**-mantissa_bits) * 2**max_exponent, e.g. FP16: 65504
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
    }
637
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100638 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000639 def _get_data_range(testGen, dtype, highValueLookup):
640 if dtype in highValueLookup:
641 data_range = testGen.getDTypeRange(dtype, high_inclusive=True)
642 high_val = highValueLookup[dtype]
643 # Set the values to something that won't produce infinity whilst
644 # respecting the default ranges if less than the high value
645 return [
646 max(-high_val, data_range[0]),
647 min(high_val, data_range[1]),
648 ]
649 return None
650
    @staticmethod
    def tvgLazyGenDefault(
        testGen, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        """Create tensors plus data-generation meta-data (TVGInfo).

        In lazy mode the actual data is left for the data generator library to
        produce later; otherwise data is generated now via testGen.dgl.
        Returns a TVGInfo whose dataGenDict is None on the fallback path.
        """
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]

        if (
            error_name is not None
            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
            or "data_gen" not in testGen.TOSA_OP_LIST[opName]
        ):
            # Fall back to original path when dealing with unsupported types or ops

            # First turn off lazy data gen so we always produce data
            lazy_data_gen = testGen.args.lazy_data_gen
            testGen.args.lazy_data_gen = False

            tens_ser_list = TosaTensorValuesGen.tvgDefault(
                testGen,
                testGen.TOSA_OP_LIST[opName],
                dtypeList,
                shapeList,
                [],
                error_name,
            )
            # Restore lazy data gen setting
            testGen.args.lazy_data_gen = lazy_data_gen
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
        dg_type = argsDict["dg_type"]
        tens_data = {
            "version": "0.1",
            "tensors": {},
        }
        dg_tens_meta = tens_data["tensors"]
        tens_ser_list = []
        for idx, shape in enumerate(shapeList):

            tens_meta = {}
            tens_meta["generator"] = gtu.DataGenType(dg_type).name
            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()

            # Operands before pCount are the op's variable inputs
            if idx < pCount:
                tens_meta["input_type"] = "VARIABLE"
            else:
                tens_meta["input_type"] = "CONSTANT"

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                # TODO - generate seed for this generator based on test
                info["rng_seed"] = 42
                if "data_range" in argsDict:
                    data_range = argsDict["data_range"]
                else:
                    data_range = testGen.getDTypeRange(
                        dtypeList[idx], high_inclusive=True
                    )
                info["range"] = [str(v) for v in data_range]
                tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = int(argsDict["ks"])
                if "acc_type" in argsDict:
                    # Convert type number into JSON name
                    info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
                        "json"
                    ]
                if "kernel" in argsDict:
                    info["kernel"] = [int(k) for k in argsDict["kernel"]]
                if "axis" in argsDict:
                    info["axis"] = int(argsDict["axis"])
                tens_meta["dot_product_info"] = info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"

            # Using the finished generate config meta data - generate the data if
            # needed and assign a tensor name from the serializer

            # Need to generate data when not lazy or for the bias tensor as we need
            # to work out if the bias data is non-zero for compliance
            if not testGen.args.lazy_data_gen or (
                idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
            ):
                # Give this tensor a temporary name until we get one from the serializer
                temp_name = f"placeholder_{idx}"
                dg_tens_meta[temp_name] = tens_meta
                # Create data now using the temporary name to access meta details
                data = testGen.dgl.get_tensor_data(temp_name, tens_data)
                # Remove the item as we will give it the correct name later
                del dg_tens_meta[temp_name]

                if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
                    # The KS value used by compliance verification is altered when the
                    # bias data is non-zero
                    if max(abs(data)) > 0.0:
                        argsDict["ksb"] = argsDict["ks"] + 1

            # In lazy mode no data is serialized (even the bias data generated
            # above is only needed for the ksb check)
            if testGen.args.lazy_data_gen:
                data = None

            if tens_meta["input_type"] == "VARIABLE":
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
            else:
                tens = testGen.ser.addConst(shape, dtypeList[idx], data)

            tens_ser_list.append(tens)
            # Add the meta data to the list using the serializer tensor name
            dg_tens_meta[tens.name] = tens_meta

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
768
769 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000770 def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson0e463642022-05-03 12:10:23 +0100771 if dtypeList[0] == DType.INT32 and error_name is None:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100772 pCount, cCount = op["operands"]
773 assert (
774 pCount == 1 and cCount == 0
775 ), "Op.NEGATE must have 1 placeholders, 0 consts"
Jeremy Johnson0e463642022-05-03 12:10:23 +0100776 # Must create tensors with values within accumulator (int32) negatable
777 # range
778 max_val = (1 << 31) - 1
779 min_val = -max_val
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100780 arr = np.int32(
781 testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
782 )
783 placeholders = []
784 placeholders.append(
785 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
786 )
787 return placeholders
788 else:
789 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000790 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100791 )
792
    # Set the data range to half the largest value so that adding or
    # subtracting two extreme values cannot overflow to infinity
    TVG_FLOAT_HIGH_VALUE_ADDSUB = {
        DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
        DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
        DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
    }
799
    @staticmethod
    def tvgAddSub(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        """Create ADD/SUB input data that avoids INT32 result saturation.

        For INT32 the second operand is clipped so a +/- b stays within int32;
        all other cases use the lazy generator with a capped float data range.
        """
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Make sure the integer operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            tens_ser_list = []
            add = op["op"] == Op.ADD
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            # Compute the wide (int64) result to detect saturation
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            data_range = TosaTensorValuesGen._get_data_range(
                testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
            )
            if data_range:
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
876
877 @staticmethod
878 def tvgCondIfWhileLoop(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000879 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100880 ):
881 if dtypeList[0] in (
882 DType.INT32,
883 DType.INT16,
884 DType.INT8,
885 ):
886 # Limit input tensors with cond_if_binary or while_loop to stop
887 # saturation of add/sub ops with int32 and keep all logical shift
888 # values between 0 to 31 for int16 or int8
889 pCount, cCount = op["operands"]
890 pRemain = pCount
891 placeholders = []
892 for idx, shape in enumerate(shapeList[:]):
893 if dtypeList[0] == DType.INT32:
894 arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
895 else:
896 arr = np.int32(
897 testGen.rng.integers(low=0, high=32, size=shapeList[idx])
898 )
899 if pRemain > 0:
900 placeholders.append(
901 testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
902 )
903 pRemain -= 1
904 else:
905 placeholders.append(
906 testGen.ser.addConst(shape, dtypeList[idx], arr)
907 )
908
909 return placeholders
910 else:
911 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000912 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100913 )
914
915 @staticmethod
916 def tvgArithmeticRightShift(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000917 testGen, op, dtypeList, shapeList, testArgs, error_name=None
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100918 ):
919 pCount, cCount = op["operands"]
920 # Force value of operand[1] to be within [0, num_bits]
921 assert (
922 pCount == 2 and cCount == 0
923 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
924
925 placeholders = []
926 for idx, shape in enumerate(shapeList[:]):
927 if idx == 1:
928 if dtypeList[idx] == DType.INT8:
929 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
930 elif dtypeList[idx] == DType.INT16:
931 arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
932 elif dtypeList[idx] == DType.INT32:
933 arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
934 elif error_name == ErrorIf.WrongInputType:
935 arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
936 else:
937 raise Exception("OpArithmeticRightShift: invalid input dtype")
938 else:
939 arr = testGen.getRandTensor(shape, dtypeList[idx])
940 placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
941
942 return placeholders
943
944 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000945 def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100946 # Set datatype of condition tensor to boolean
947 dtypeList[0] = DType.BOOL
948
949 return TosaTensorValuesGen.tvgDefault(
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000950 testGen, op, dtypeList, shapeList, testArgs, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100951 )
952
953 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000954 def tvgIntDiv(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100955 if error_name is None:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000956 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100957 pCount, cCount = op["operands"]
958 assert (
959 pCount == 2 and cCount == 0
960 ), "Op.INTDIV must have 2 placeholders, 0 consts"
961
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000962 tens_ser_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100963
964 # Two invalid cases for Op.INTDIV:
965 # 1. divisor == 0
966 # 2. dividend == -(1<<31) and divisor == -1
967 while True:
968 dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
969 divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
970
971 if (divisor_arr == 0).any():
972 continue
973
974 if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
975 continue
976
977 break
978
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000979 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100980 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
981 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000982 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100983 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
984 )
985
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000986 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100987 else:
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +0000988 return TosaTensorValuesGen.tvgLazyGenDefault(
989 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100990 )
991
    # Data range for MUL tests - the square root of the highest value of each
    # floating point type, so the product of two in-range operands stays
    # within the representable range
    TVG_FLOAT_HIGH_VALUE_MUL = {
        DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
998
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100999 @staticmethod
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001000 def tvgMul(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
1001 if error_name is not None or dtypeList[0] in (
1002 DType.FP16,
1003 DType.BF16,
1004 DType.FP32,
1005 ):
1006 # ERROR_IF or floating point test
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001007 data_range = TosaTensorValuesGen._get_data_range(
1008 testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
1009 )
1010 if data_range:
1011 argsDict["data_range"] = data_range
1012
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001013 return TosaTensorValuesGen.tvgLazyGenDefault(
1014 testGen, opName, dtypeList, shapeList, argsDict, error_name
1015 )
1016 else:
1017 # Integer test
1018 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001019 pCount, cCount = op["operands"]
1020 assert (
1021 pCount == 2 and cCount == 0
1022 ), "Op.MUL must have 2 placeholders, 0 consts"
1023
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001024 tens_ser_list = []
1025
1026 # Make sure multiply result in int32 range
1027 shift = argsDict["shift"]
1028 if dtypeList[0] == DType.INT8:
1029 num_bits = 8
1030 elif dtypeList[0] == DType.INT16:
1031 num_bits = 16
1032 elif dtypeList[0] == DType.INT32:
1033 num_bits = 32
1034 elif error_name == ErrorIf.WrongInputType:
1035 num_bits = 8
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001036 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001037 raise Exception("OpMul: invalid input dtype")
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001038
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001039 for idx, shape in enumerate(shapeList[:]):
1040 low = -(2 ** (num_bits - 1))
1041 high = (2 ** (num_bits - 1)) - 1
1042
1043 a_arr = np.int32(
1044 testGen.rng.integers(low=low, high=high, size=shapeList[0])
1045 )
1046 b_arr = np.int32(
1047 testGen.rng.integers(low=low, high=high, size=shapeList[1])
1048 )
1049
1050 i = 0
1051 while True:
1052
1053 a_arr_64 = a_arr.astype(np.int64)
1054 b_arr_64 = b_arr.astype(np.int64)
1055
1056 if shift > 0:
1057 rounding = 1 << (shift - 1)
1058 result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001059 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001060 result_arr = a_arr_64 * b_arr_64
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001061
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001062 if (result_arr > -(2**31)).all() and (
1063 result_arr <= ((2**31) - 1)
1064 ).all():
1065 break
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001066
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001067 i = i + 1
1068 a_arr = a_arr // 2
1069 b_arr = b_arr // 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001070
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001071 tens_ser_list.append(
1072 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001073 )
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01001074 tens_ser_list.append(
1075 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
1076 )
1077
1078 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001079
1080 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001081 def tvgConcat(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001082 count = len(shapeList) - testGen.args.num_const_inputs_concat
1083 if count < 1:
1084 count = 1
1085 if testGen.args.num_const_inputs_concat == 0:
1086 count = len(shapeList)
1087
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001088 shapeList = TosaTensorGen.tgConcatConstInput(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001089 testGen, shapeList, argsDict["axis"], error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001090 )
1091
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001092 tens_ser_list = []
1093 tens_ser_list.extend(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001094 testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
1095 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001096 tens_ser_list.extend(
1097 testGen.buildConstTensors(shapeList[count:], dtypeList[count:])
1098 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001099
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001100 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001101
1102 @staticmethod
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001103 def tvgLogicalShift(
1104 testGen, opName, dtypeList, shapeList, argsDict, error_name=None
1105 ):
1106 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001107 pCount, cCount = op["operands"]
1108 assert (
1109 pCount == 2 and cCount == 0
1110 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
1111 values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
1112 shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001113 tens_ser_list = []
1114 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001115 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1116 )
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001117 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001118 testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
1119 )
1120
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001121 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001122
    @staticmethod
    def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        """Tensor values for EQUAL tests.

        Random data would rarely contain matching element pairs, so a number
        of positions (2 x rank) are forced to hold equal values across the
        two (possibly broadcast) inputs.  ERROR_IF tests use the default
        generator.
        """
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        testGen.rng.integers(
                            0, max(shapeList[0][axis], shapeList[1][axis])
                        )
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                # Copy B's value into A so the two tensors match at this point
                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
1164
1165 @staticmethod
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001166 def tvgReduceSum(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001167 if dtypeList[0] == DType.INT32:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001168 op = testGen.TOSA_OP_LIST[opName]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001169 pCount, cCount = op["operands"]
1170 assert (
1171 pCount == 1 and cCount == 0
1172 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
1173 # Limit values so that the sum cannot exceed the range of an int32 during
1174 # summation of any axis
1175 range_val = int((1 << 31) / max(shapeList[0]))
1176 values_arr = np.int32(
1177 testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
1178 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001179 tens_ser_list = []
1180 tens_ser_list.append(
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001181 testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
1182 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001183 return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001184 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001185 # ERROR_IF or dot product floating point test
1186 return TosaTensorValuesGen.tvgLazyGenDefault(
1187 testGen, opName, dtypeList, shapeList, argsDict, error_name
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001188 )
1189
1190
1191class TosaArgGen:
1192 """Argument generators create exhaustive or random lists of attributes for
1193 operators that take attributes or other parameters.
1194
1195 The return value is a list of (descriptive_name, [arglist]) tuples where
1196 the descriptive_name is appended to the test name and the arglist is expanded
1197 as arguments to the operator build function.
1198 """
1199
1200 def __init__(self):
1201 pass
1202
1203 @staticmethod
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001204 def _add_data_generators(testGen, opName, dtype, arg_list, error_name):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001205 """Add extra tests for each type of data generator for this op."""
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001206 if (
1207 error_name is None
1208 and "data_gen" in testGen.TOSA_OP_LIST[opName]
1209 and gtu.dtypeIsSupportedByCompliance(dtype)
1210 ):
Jeremy Johnson1271c442023-09-05 11:39:26 +01001211 if dtype in [DType.FP16, DType.FP32, DType.BF16]:
1212 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
1213 else:
1214 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
1215 else:
1216 # Error test or No data generator types listed - assume random
1217 dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
1218
1219 # Expand arg list with other data generator types
1220 new_arg_list = []
1221 for dg_type in dataGenTypesList:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001222 for arg_str, args_dict in arg_list:
1223 args_dict["dg_type"] = dg_type
Jeremy Johnson1271c442023-09-05 11:39:26 +01001224 if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
1225 # Default test
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001226 new_arg_list.append((arg_str, args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001227
1228 elif dg_type == gtu.DataGenType.DOT_PRODUCT:
1229 # Extra tests for each dot product test set
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001230 dot_products = args_dict["dot_products"]
Jeremy Johnson1271c442023-09-05 11:39:26 +01001231 if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001232 shape_info = (
1233 " ({})".format(testGen.shapeStr(args_dict["shape"]))
1234 if "shape" in args_dict
1235 else ""
1236 )
Jeremy Johnson1271c442023-09-05 11:39:26 +01001237 print(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001238 f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
Jeremy Johnson1271c442023-09-05 11:39:26 +01001239 )
1240 continue
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001241 # KS and acc_type is required by all dot product generators
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001242 assert "ks" in args_dict
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001243 assert "acc_type" in args_dict
Jeremy Johnson1271c442023-09-05 11:39:26 +01001244
1245 for s in testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS:
1246 new_arg_str = f"{arg_str}_s{s}"
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001247 new_args_dict = args_dict.copy()
1248 new_args_dict["s"] = s
1249 new_arg_list.append((new_arg_str, new_args_dict))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001250
1251 return new_arg_list
1252
1253 @staticmethod
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001254 def agNone(testGen, opName, shapeList, dtype, error_name=None):
1255 """A trivial argument generator for operators that don't take any
1256 non-tensor arguments"""
Jeremy Johnson7bf0cb92023-10-31 14:37:54 +00001257 arg_list = TosaArgGen._add_data_generators(
1258 testGen,
1259 opName,
1260 dtype,
1261 [("", {})],
1262 error_name,
1263 )
1264 # Return list of tuples: (arg_str, args_dict)
1265 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001266
1267 @staticmethod
1268 def agAxis(testGen, opName, shapeList, dtype, error_name=None):
1269 """Build the axis argument for operators that take a single axis"""
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001270 arg_list = []
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001271 shape = shapeList[0]
1272
1273 if error_name == ErrorIf.AxisSmallerZero:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001274 # Set too small axis
1275 axes = [testGen.rng.integers(-5, 0)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001276 elif error_name == ErrorIf.AxisLargerRank:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001277 # Set too large axis
1278 axes = [testGen.rng.integers(len(shape) + 1, len(shape) + 10)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001279 else:
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001280 # Create tests for each dimension
1281 axes = range(0, len(shape))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001282
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001283 opid = testGen.TOSA_OP_LIST[opName]["op"]
1284
1285 for a in axes:
1286 args_dict = {"axis": int(a)}
1287 if opid == Op.REDUCE_SUM:
1288 args_dict["dot_products"] = gtu.product(shape)
1289 args_dict["shape"] = shape
1290 args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
1291 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32
1292
1293 arg_list.append(("axis{}".format(a), args_dict))
1294
1295 arg_list = TosaArgGen._add_data_generators(
1296 testGen,
1297 opName,
1298 dtype,
1299 arg_list,
1300 error_name,
1301 )
1302 # Return list of tuples: (arg_str, args_dict)
1303 return arg_list
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001304
1305 @staticmethod
Jeremy Johnsonfd05bb32023-02-07 16:39:24 +00001306 def _calculate_sparsity(num_tests, sparsity_factor):
1307 sparsity = num_tests // sparsity_factor + 1
1308 # If there are only a small number of tests, just select them all
1309 if sparsity < 13:
1310 sparsity = 1
1311 # To get a variety of parameter combinations sparsity should not be a
1312 # multiple of 2, 3 or 5
1313 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
1314 sparsity += 1
1315 return sparsity
1316
1317 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01001318 def agConv(testGen, opName, shapeList, dtypes, error_name=None):
Jeremy Johnson0c716862023-04-13 17:18:19 +01001319 # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001320 arg_list = []
1321
Jeremy Johnson0c716862023-04-13 17:18:19 +01001322 if testGen.args.level8k and error_name is not None:
1323 # Don't produce negative large tests
1324 return arg_list
1325
1326 # Shape: Batches, (Depth), Height, Width, Channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001327 ifm_shape = shapeList[0]
Jeremy Johnson0c716862023-04-13 17:18:19 +01001328 # Shape: (OFM channels), (KD), KH, KW, IFM channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001329 filter_shape = shapeList[1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001330
Jeremy Johnson1271c442023-09-05 11:39:26 +01001331 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01001332
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001333 # Op type checks
Jeremy Johnson0c716862023-04-13 17:18:19 +01001334 conv3d = opName.startswith("conv3d")
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001335 depthwise = opName.startswith("depthwise")
1336
1337 # Check the rank
Jeremy Johnson0c716862023-04-13 17:18:19 +01001338 rank = 5 if conv3d else 4
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001339 if error_name != ErrorIf.WrongRank:
1340 assert len(ifm_shape) == rank
1341 assert len(filter_shape) == rank
1342
Jeremy Johnson0c716862023-04-13 17:18:19 +01001343 # kernel rank omits channels
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001344 k_rank = rank - 2
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001345 k_pos = 0 if depthwise else 1
Jeremy Johnson0c716862023-04-13 17:18:19 +01001346 k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001347 # compliance size - KS
1348 k_size = gtu.product(k_shape)
1349 if not depthwise:
1350 k_size *= ifm_shape[-1]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001351
Jeremy Johnson0c716862023-04-13 17:18:19 +01001352 if not testGen.args.level8k:
1353 # Generate comprehensive argument lists
1354 # - except for named errors, which use specific invalid value(s)
1355 if error_name == ErrorIf.PadSmallerZero:
1356 p_vals = [testGen.rng.choice(range(-5, 0))]
1357 else:
1358 p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
1359 paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
1360 if error_name == ErrorIf.StrideSmallerOne:
1361 # Can't use stride=0, as it is used to derive output shape, as a divisor
1362 s_vals = [testGen.rng.choice(range(-5, 0))]
1363 else:
1364 # Stride must be greater than 1 to force non-integer error
1365 startStride = (
1366 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001367 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01001368 s_vals = [
1369 x for x in range(startStride, testGen.args.max_conv_stride + 1)
1370 ]
1371 strides = {x for x in itertools.product(*([s_vals] * k_rank))}
1372 if error_name == ErrorIf.DilationSmallerOne:
1373 d_vals = [testGen.rng.choice(range(-5, 1))]
1374 else:
1375 d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
1376 dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001377
Jeremy Johnson0c716862023-04-13 17:18:19 +01001378 if not error_name and testGen.args.oversize:
1379 # add some oversize argument values
1380 if max(ifm_shape) < 64:
1381 bigPadding = 9
1382 paddings.update(
1383 {
1384 x
1385 for x in itertools.product(
1386 *([[0, bigPadding]] * (k_rank * 2))
1387 )
1388 }
1389 )
1390 bigStride = 8
1391 strides.update(
1392 {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
1393 )
1394 bigDilation = 7
1395 dilations.update(
1396 {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
1397 )
1398 max_dim_size = None
1399
1400 # There are too many parameter combinations, so generate them sparsely,
1401 # very sparse for negative tests
1402 sparsity_factor = 2 if error_name else 120
1403 sparsity = TosaArgGen._calculate_sparsity(
1404 len(paddings) * len(strides) * len(dilations), sparsity_factor
1405 )
1406 else:
1407 # Only test 8k levels boundaries
1408 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
1409 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
1410 bigPadding = bigKernel
1411
1412 dilation_shape = [1] * k_rank
1413 pad_shape = [0] * k_rank * 2
1414 if conv3d:
1415 # Small stride apart from for big kernel (see below) to keep
1416 # tensor size/calculation small
1417 stride_shape = [1] * k_rank
1418 for idx in range(k_rank):
1419 pad_offset = idx * 2
1420 if k_shape[idx] == bigKernel:
1421 # Padding shape needs to account for tensor shape
1422 pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
1423 pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
1424 # Big stride to reduce output size
1425 stride_shape[idx] = bigKernel
1426 else:
1427 # Account for kernel size
1428 pad_shape[pad_offset] = k_shape[idx] - 1
1429 else:
1430 # Always have a large stride with extra padding and dilation to keep
1431 # tensor calculation reasonable
1432 stride_shape = [bigKernel] * k_rank
1433 for idx in range(k_rank):
1434 # Dilation shape must account for kernel size
1435 dilation_shape[idx] = bigKernel // k_shape[idx]
1436 # Padding shape needs to accommodate tensor/kernel & dilation
1437 pad_offset = idx * 2
1438 pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
1439 pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
1440
1441 strides = {tuple(stride_shape)}
1442 dilations = {tuple(dilation_shape)}
1443 paddings = {tuple(pad_shape)}
1444 # Create a limit for the output dimensions size
1445 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
1446
1447 # Currently allow all combinations that are reasonable size
1448 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001449
1450 n = 0
1451 for s in sorted(list(strides)):
1452 for p in sorted(list(paddings)):
1453 for d in sorted(list(dilations)):
1454 if (
1455 n % sparsity == 0
Jeremy Johnson93d43902022-09-27 12:26:14 +01001456 # the padded shape must exceed the dilation * kernel to get a positive
1457 # sized output shape
Jeremy Johnson0c716862023-04-13 17:18:19 +01001458 and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
1459 and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
Jeremy Johnson93d43902022-09-27 12:26:14 +01001460 and (
1461 k_rank < 3
Jeremy Johnson0c716862023-04-13 17:18:19 +01001462 or (
1463 (ifm_shape[3] - 1 + p[4] + p[5])
1464 > d[2] * (k_shape[2] - 1)
1465 )
Jeremy Johnson93d43902022-09-27 12:26:14 +01001466 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001467 ):
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001468 remainders = []
Jeremy Johnson0c716862023-04-13 17:18:19 +01001469 outputs = []
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001470 for index in range(k_rank):
1471 pad_offset = index * 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01001472 partial = (
1473 ifm_shape[index + 1]
1474 - 1
1475 + p[pad_offset]
1476 + p[pad_offset + 1]
1477 - (k_shape[index] - 1) * d[index]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001478 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01001479 remainders.append(partial % s[index])
1480 outputs.append((partial // s[index]) + 1)
1481
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001482 if (
1483 # the parameters must produce integer exact output
1484 error_name != ErrorIf.ConvOutputShapeNonInteger
1485 and max(remainders) == 0
1486 ) or (
1487 error_name == ErrorIf.ConvOutputShapeNonInteger
1488 and max(remainders) > 0
1489 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01001490 if (
1491 max_dim_size is not None
1492 and max(outputs) >= max_dim_size
1493 ):
1494 # Test will consume too much memory - skip it
1495 continue
1496
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001497 # Compliance - number of dot product calculations
1498 if depthwise:
1499 # TODO - add support
1500 dots = 0
1501 else:
1502 dots = gtu.product(
1503 (ifm_shape[0], *outputs, filter_shape[0])
1504 )
1505 args_dict = {
1506 "acc_type": accum_dtype,
1507 "stride": s,
1508 "pad": p,
1509 "dilation": d,
1510 "kernel": k_shape,
1511 "ks": k_size,
1512 "dot_products": dots,
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001513 "shape": ifm_shape,
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001514 }
1515
Jeremy Johnson0c716862023-04-13 17:18:19 +01001516 # Support for larger values than 9 needs different delimiter
1517 delim = "" if max(s + p + d) <= 9 else "x"
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001518 arg_list.append(
1519 (
James Ward8b390432022-08-12 20:48:56 +01001520 "acc{}_st{}_pad{}_dilat{}".format(
1521 testGen.typeStr(accum_dtype),
Jeremy Johnson0c716862023-04-13 17:18:19 +01001522 delim.join([str(x) for x in s]),
1523 delim.join([str(x) for x in p]),
1524 delim.join([str(x) for x in d]),
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001525 ),
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001526 args_dict,
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001527 )
1528 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001529 n += 1
1530
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01001531 arg_list = TosaArgGen._add_data_generators(
1532 testGen,
1533 opName,
1534 dtypes[0],
1535 arg_list,
1536 error_name,
1537 )
1538 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001539 return arg_list
1540
1541 @staticmethod
James Ward8b390432022-08-12 20:48:56 +01001542 def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
1543
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001544 assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected"
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001545 input_dtype = dtypes[0]
James Ward8b390432022-08-12 20:48:56 +01001546
1547 if error_name == ErrorIf.WrongOutputType:
Jeremy Johnson1271c442023-09-05 11:39:26 +01001548 accum_dtype = gtu.get_wrong_output_type(opName, testGen.rng, input_dtype)
James Ward8b390432022-08-12 20:48:56 +01001549 elif error_name == ErrorIf.WrongInputType:
1550 # Pick some potentially correct output dtype if input type is incorrect
1551 accum_dtype = DType.INT32
1552 else:
Jeremy Johnson1271c442023-09-05 11:39:26 +01001553 accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
James Ward8b390432022-08-12 20:48:56 +01001554
Jeremy Johnsonaee62af2023-11-02 17:16:25 +00001555 # Set up compliance info
1556 args_dict = {
1557 "acc_type": accum_dtype,
1558 "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC)
1559 "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])),
1560 "shape": shapeList[0],
1561 }
1562
1563 arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)]
1564
1565 arg_list = TosaArgGen._add_data_generators(
1566 testGen,
1567 opName,
1568 input_dtype,
1569 arg_list,
1570 error_name,
1571 )
1572 # Return list of tuples: (arg_str, args_dict)
1573 return arg_list
James Ward8b390432022-08-12 20:48:56 +01001574
1575 @staticmethod
1576 def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
1577 # Get valid accumulate type(s)
1578 if dtype == DType.INT8:
1579 accum_dtypes = [DType.INT32]
1580 elif dtype == DType.INT16:
1581 accum_dtypes = [DType.INT48]
1582 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001583 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01001584 elif dtype == DType.BF16:
1585 accum_dtypes = [DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001586 elif dtype == DType.FP32:
1587 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01001588 elif error_name is None:
1589 assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
1590
1591 if error_name == ErrorIf.WrongOutputType:
1592 # Get incorrect output dtype for ErrorIf case
Jeremy Johnson1271c442023-09-05 11:39:26 +01001593 accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)]
James Ward8b390432022-08-12 20:48:56 +01001594 elif error_name == ErrorIf.WrongInputType:
1595 # Pick some potentially correct output dtype if input type is incorrect
1596 accum_dtypes = [DType.INT32]
1597
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001598 # Set up compliance info
1599 args_dict = {
1600 "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
1601 # Set dot_products = N*H*W
1602 "dot_products": gtu.product(
1603 (shapeList[0][0], shapeList[0][1], shapeList[1][2])
1604 ),
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001605 "shape": shapeList[0],
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001606 }
1607
1608 # Create arg tuple of string and dict
1609 arg_list = []
1610 for a in accum_dtypes:
1611 d = args_dict.copy()
1612 d["acc_type"] = a
1613 arg_list.append((f"acc{testGen.typeStr(a)}", d))
Jeremy Johnson1271c442023-09-05 11:39:26 +01001614
1615 arg_list = TosaArgGen._add_data_generators(
1616 testGen,
1617 opName,
1618 dtype,
1619 arg_list,
1620 error_name,
Jeremy Johnson1271c442023-09-05 11:39:26 +01001621 )
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001622 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson1271c442023-09-05 11:39:26 +01001623 return arg_list
James Ward8b390432022-08-12 20:48:56 +01001624
    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
        """Generate stride/padding/output-shape arguments for TRANSPOSE_CONV2D.

        Args:
            testGen: test generator providing rng, command-line args and helpers
            shapeList: [ifm_shape, filter_shape, ...] input shapes
            dtypes: tensor dtypes chosen by the type generator
            error_name: optional ErrorIf case to produce invalid values for

        Returns:
            List of (arg_str, [accum_dtype, stride, pad, output_shape]) tuples.
        """
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Must be rank 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == 4
            assert len(filter_shape) == 4

        # Kernel height/width taken from the filter shape (OC,KH,KW,IC)
        k_shape = tuple(filter_shape[1:3])

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            # Most negative padding allowed without erroring: -(min kernel dim) + 1
            smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
            if error_name == ErrorIf.PadLargerEqualKernel:
                max_filter_size = -max(k_shape[0], k_shape[1])
                p_vals = [
                    testGen.rng.choice(range(max_filter_size - 10, max_filter_size))
                ]
            else:
                p_vals = [
                    x
                    for x in range(
                        smallest_padding_size, testGen.args.max_conv_padding + 1
                    )
                ]
            paddings = {x for x in itertools.product(*([p_vals] * 4))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
            strides = {x for x in itertools.product(*([s_vals] * 2))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[smallest_padding_size, bigPadding]] * 4)
                            )
                        }
                    )
                bigStride = 8
                strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 10
            sparsity = len(paddings) * len(strides) // sparsity_factor + 1
            # If there are only a small number of tests, just select them all
            if sparsity < 13:
                sparsity = 1
            # To get a variety of parameter combinations sparsity should not be a
            # multiple of 2, 3 or 5
            while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
                sparsity += 1
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            pad_shape = [0] * (len(k_shape) * 2)
            stride_shape = [1] * len(k_shape)
            # The point at which input dimension combined with the stride will
            # create large output sizes!
            LARGE_SIZE = 2
            for idx in range(len(k_shape)):
                pad_offset = idx * 2
                if k_shape[idx] == bigKernel:
                    # Set large stride
                    stride_shape[idx] = bigKernel
                    # Use negative output padding to reduce shape size
                    pad_shape[pad_offset] = -(bigPadding - 1)
                    if ifm_shape[idx + 1] > LARGE_SIZE:
                        pad_shape[pad_offset + 1] = -(bigPadding - 1)
                else:
                    # The other dimension should be the bigKernel
                    alt_idx = 1 - idx
                    if (
                        k_shape[alt_idx] == bigKernel
                        and ifm_shape[alt_idx + 1] < LARGE_SIZE
                    ):
                        # As the input is small, the large stride won't
                        # affect the output so we can add some padding
                        pad_shape[pad_offset + 1] = bigPadding

            strides = {tuple(stride_shape)}
            paddings = {tuple(pad_shape)}

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                if n % sparsity == 0:
                    # Determine the output shape
                    # transpose conv: OH = (IH-1)*stride + out_pad_top + out_pad_bottom + KH
                    oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
                    ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]

                    # Support for larger values than 9 needs different delimiter
                    delim = "" if max(s + p) <= 9 else "x"
                    arg_list.append(
                        (
                            "acc{}_st{}_pad{}_os{}".format(
                                testGen.typeStr(accum_dtype),
                                delim.join([str(x) for x in s]),
                                delim.join([str(x) for x in p]),
                                "x".join([str(x) for x in os]),
                            ),
                            [accum_dtype, s, p, os],
                        )
                    )
                n += 1

        return arg_list
1757
    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        """Generate padding and pad-constant arguments for PAD tests.

        Exhaustively combines before/after padding values for every dimension
        of the input shape, using sparsity when the combination count is
        large.  Each args_dict holds "pad" (rank x 2 array), "pad_const_int"
        and "pad_const_fp".
        """
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            # Negative values only, to trigger the ERROR_IF case
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        # Random pad constant matching the dtype family; the other is zero
        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            # Unsupported dtype - nothing to generate
            return []

        list_shape_pad_values = list(shape_pad_values)
        # If we are producing tests for rank 6 or greater use sparsity
        if len(list_shape_pad_values) > 1024:
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(list_shape_pad_values), sparsity_factor
            )
        else:
            sparsity = 1

        # Build arg list
        arg_list = []
        for n, paddings in enumerate(list_shape_pad_values):
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                    if all([p > -1 for p in paddings[i]]):
                        args_valid = False
            if args_valid and n % sparsity == 0:
                # Encode the before/after values for each dimension in the name
                name = "pad"
                for r in range(rank):
                    before, after = paddings[r]
                    name = f"{name}{before}{after}"
                args_dict = {
                    "pad": np.array(paddings),
                    "pad_const_int": pad_const_int,
                    "pad_const_fp": pad_const_fp,
                }
                arg_list.append((name, args_dict))

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list
1832
1833 @staticmethod
1834 def agPooling(testGen, opName, shapeList, dtype, error_name=None):
1835 arg_list = []
1836
1837 shape = shapeList[0]
1838 if error_name != ErrorIf.WrongRank:
1839 assert len(shape) == 4
1840
Jeremy Johnson0c716862023-04-13 17:18:19 +01001841 test_level8k = testGen.args.level8k and error_name is None
1842
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001843 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
Jeremy Johnson0c716862023-04-13 17:18:19 +01001844 startKernel = 2
1845 startPad = 0
1846 if not test_level8k:
1847 # Generate comprehensive argument lists
1848 p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
1849 paddings = {x for x in itertools.product(*([p_vals] * 4))}
1850 # Stride must be greater than 1 to force non-integer error
1851 s_vals = [
1852 x for x in range(startStride, testGen.args.max_pooling_stride + 1)
1853 ]
1854 strides = {x for x in itertools.product(*([s_vals] * 2))}
1855 k_vals = [
1856 x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
1857 ]
1858 kernels = {x for x in itertools.product(*([k_vals] * 2))}
1859 max_dim_size = None
1860 else:
1861 # Only test 8k levels
1862 bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
1863 bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
1864 strides = {(1, bigStride), (bigStride, 4)}
1865 kernels = {(1, bigKernel), (bigKernel, 3)}
1866 paddings = set()
1867 for s in sorted(list(strides)):
1868 for k in sorted(list(kernels)):
1869 padding = []
1870 for idx in range(len(k)):
1871 total_padding = s[idx] - shape[idx + 1] + k[idx]
1872 while total_padding < 0:
1873 # Must meet: shape + padding > kernel
1874 total_padding += s[idx]
1875 if total_padding < k[idx]:
1876 padding.extend([0, total_padding])
1877 else:
1878 # Note this may produce padding >= k[idx] which is not
1879 # allowed - but will be ignored in the creation loop below
1880 padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
1881 paddings.add(tuple(padding))
1882 # Create a limit for the output dimensions size
1883 max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001884
James Ward8b390432022-08-12 20:48:56 +01001885 if opName == "max_pool2d":
1886 accum_dtypes = [None] # max_pool has no accumulate dtype
1887 elif dtype == DType.INT8 or dtype == DType.INT16:
1888 accum_dtypes = [DType.INT32]
1889 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001890 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01001891 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001892 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01001893 elif error_name is None:
1894 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
1895 else:
1896 # Set to something for the ErrorIf case which has
1897 # incorrect input data-type
1898 accum_dtypes = [DType.INT32]
1899
Jeremy Johnson0c716862023-04-13 17:18:19 +01001900 if not test_level8k:
1901 if testGen.args.oversize:
1902 # add some oversize argument values
1903 bigStride = 7
1904 bigKernel = 9
1905 strides.update(
1906 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001907 )
Jeremy Johnson0c716862023-04-13 17:18:19 +01001908 kernels.update(
1909 {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
1910 )
1911 if max(shape) < 64:
1912 # padding must be less than the kernel size
1913 bigPadding = bigKernel - 1
1914 paddings.update(
1915 {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
1916 )
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001917
Jeremy Johnson0c716862023-04-13 17:18:19 +01001918 # There are too many parameter combinations, so generate them sparsely,
1919 # very sparse for negative tests
1920 sparsity_factor = 2 if error_name else 500
1921 sparsity = (
1922 len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
1923 )
1924 else:
1925 # We have already limited test output combinations for 8k tests
1926 sparsity = 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001927
James Ward8b390432022-08-12 20:48:56 +01001928 arg_str = (
1929 "acc{}_st{}_kern{}_pad{}"
1930 if accum_dtypes[0] is not None
1931 else "st{}_kern{}_pad{}"
1932 )
1933
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001934 def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
James Ward8b390432022-08-12 20:48:56 +01001935 # Return tuple containing the formatted argument string and
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001936 # the corresponding argument values in a dictionary
Jeremy Johnson0c716862023-04-13 17:18:19 +01001937
1938 # Support for larger values than 9 needs different delimiter
1939 delim = "" if max(stride + kern + pad) <= 9 else "x"
James Ward8b390432022-08-12 20:48:56 +01001940 arg_str_elems = [
Jeremy Johnson0c716862023-04-13 17:18:19 +01001941 delim.join([str(x) for x in stride]),
1942 delim.join([str(x) for x in kern]),
1943 delim.join([str(x) for x in pad]),
James Ward8b390432022-08-12 20:48:56 +01001944 ]
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001945 args_dict = {
1946 "stride": stride,
1947 "pad": pad,
1948 "kernel": kern,
1949 "dot_products": dot_products, # Ignored for error tests
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001950 "shape": shape,
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001951 "ks": gtu.product(kern), # avg_pool2d: KS = KX*KY
1952 }
James Ward8b390432022-08-12 20:48:56 +01001953
1954 if accum is not None:
1955 arg_str_elems.insert(0, testGen.typeStr(accum))
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001956 args_dict["acc_type"] = accum
1957 return (arg_str.format(*arg_str_elems), args_dict)
James Ward8b390432022-08-12 20:48:56 +01001958
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001959 n = 0
James Ward8b390432022-08-12 20:48:56 +01001960 for a in accum_dtypes:
1961 for s in sorted(list(strides)):
1962 for p in sorted(list(paddings)):
1963 for k in sorted(list(kernels)):
1964 if error_name in [
1965 ErrorIf.StrideSmallerOne,
1966 ErrorIf.KernelSmallerOne,
1967 ErrorIf.PadSmallerZero,
1968 ErrorIf.PadLargerEqualKernel,
1969 ]:
1970 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
1971 testGen, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001972 )
James Ward8b390432022-08-12 20:48:56 +01001973 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001974 arg_list.append(
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00001975 get_arg_list_element(a, sNew, pNew, kNew, shape)
Jeremy Johnsond41feb72023-10-12 16:03:15 +01001976 )
James Ward8b390432022-08-12 20:48:56 +01001977 elif (
1978 n % sparsity == 0
1979 # padding must not exceed the kernel size
1980 and p[0] < k[0]
1981 and p[1] < k[0]
1982 and p[2] < k[1]
1983 and p[3] < k[1]
1984 # the padded shape must exceed the kernel size
1985 and (shape[1] + p[0] + p[1]) > k[0]
1986 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001987 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01001988 partial_h = shape[1] + p[0] + p[1] - k[0]
1989 partial_w = shape[2] + p[2] + p[3] - k[1]
1990 remainder_h = partial_h % s[0]
1991 remainder_w = partial_w % s[1]
1992 output_h = partial_h // s[0] + 1
1993 output_w = partial_w // s[1] + 1
1994 # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
James Ward8b390432022-08-12 20:48:56 +01001995 if (
1996 # the parameters must produce integer exact output
1997 error_name != ErrorIf.PoolingOutputShapeNonInteger
1998 and remainder_h == 0
1999 and remainder_w == 0
2000 ) or (
2001 error_name == ErrorIf.PoolingOutputShapeNonInteger
2002 and (remainder_h != 0 or remainder_w != 0)
2003 ):
Jeremy Johnson0c716862023-04-13 17:18:19 +01002004 if (
2005 max_dim_size is not None
2006 and max(output_h, output_w) > max_dim_size
2007 ):
2008 # Test will consume too much memory - skip it
2009 continue
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002010 # Dot products = N*OH*OW*C
2011 dp = gtu.product(
2012 (shape[0], output_h, output_w, shape[3])
2013 )
Jeremy Johnsonbfc53032023-11-01 11:29:56 +00002014 arg_list.append(
2015 get_arg_list_element(a, s, p, k, dp, shape)
2016 )
James Ward8b390432022-08-12 20:48:56 +01002017 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002018
Jeremy Johnsond41feb72023-10-12 16:03:15 +01002019 # Now add data generator types
2020 arg_list = TosaArgGen._add_data_generators(
2021 testGen,
2022 opName,
2023 dtype,
2024 arg_list,
2025 error_name,
2026 )
2027
2028 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002029 return arg_list
2030
2031 @staticmethod
2032 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
2033 arg_list = []
2034
2035 # Enumerate the output types here
2036 if error_name == ErrorIf.WrongOutputType:
2037 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
2038 elif inDtype == DType.INT8:
James Ward736fd1a2023-01-23 17:13:37 +00002039 dtypeList = [
2040 DType.BOOL,
2041 DType.INT16,
2042 DType.INT32,
2043 DType.FP16,
2044 DType.BF16,
2045 DType.FP32,
2046 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002047 elif inDtype == DType.INT16:
James Ward736fd1a2023-01-23 17:13:37 +00002048 dtypeList = [
2049 DType.BOOL,
2050 DType.INT8,
2051 DType.INT32,
2052 DType.FP16,
2053 DType.BF16,
2054 DType.FP32,
2055 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002056 elif inDtype == DType.INT32:
James Ward736fd1a2023-01-23 17:13:37 +00002057 dtypeList = [
2058 DType.BOOL,
2059 DType.INT8,
2060 DType.INT16,
2061 DType.FP16,
2062 DType.BF16,
2063 DType.FP32,
2064 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002065 elif inDtype == DType.BOOL:
2066 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01002067 elif inDtype == DType.FP16:
James Ward736fd1a2023-01-23 17:13:37 +00002068 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01002069 elif inDtype == DType.BF16:
James Ward736fd1a2023-01-23 17:13:37 +00002070 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002071 elif inDtype == DType.FP32:
James Ward736fd1a2023-01-23 17:13:37 +00002072 dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002073 elif error_name == ErrorIf.WrongInputType:
2074 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01002075 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002076 else:
2077 raise Exception("Unexpected input dtype: {}".format(inDtype))
2078
2079 for dtype in dtypeList:
Jeremy Johnson3b0544c2022-10-18 16:32:19 +01002080 arg_list.append(("out{}".format(testGen.typeStr(dtype)), [dtype]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002081
2082 return arg_list
2083
    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        """Generate argument combinations for RESCALE tests.

        Enumerates output dtype x scale32 x double_round x per_channel,
        skipping combinations that are invalid for the input dtype or that do
        not match the requested ErrorIf case.

        Returns:
            List of (arg_str, [outDtype, scale32, double_round, per_channel]).
        """
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                # These output dtypes are allowed a non-zero output zero point
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition. Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [outDtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list
2182
2183 @staticmethod
2184 def agMul(testGen, opName, shapeList, dtype, error_name=None):
2185 arg_list = []
2186
2187 if dtype is DType.INT32:
2188 for p in range(testGen.args.num_rand_permutations):
2189
2190 shift = testGen.randInt(0, 32)
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002191 arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002192 else:
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002193 arg_list.append(("perm0_shift0", {"shift": 0}))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002194
Jeremy Johnsona4d907e2023-10-26 13:53:14 +01002195 arg_list = TosaArgGen._add_data_generators(
2196 testGen,
2197 opName,
2198 dtype,
2199 arg_list,
2200 error_name,
2201 )
2202 # Return list of tuples: (arg_str, args_dict)
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002203 return arg_list
2204
2205 @staticmethod
2206 def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
2207 arg_list = []
2208
2209 arg_list.append(("roundTrue", [True]))
2210 arg_list.append(("roundFalse", [False]))
2211
2212 return arg_list
2213
Luke Hutton57287132023-02-06 14:54:18 +00002214 @staticmethod
2215 def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
2216 arg_list = []
2217
2218 arg_list.append(("inverseTrue", [True]))
2219 arg_list.append(("inverseFalse", [False]))
2220
2221 return arg_list
2222
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01002223 # Helper function for reshape. Gets some factors of a larger number.
2224 @staticmethod
2225 def getFactors(val, start=1):
2226 factors = []
2227
2228 for i in range(start, int(np.sqrt(val)) + 1):
2229 if (val % i) == 0:
2230 factors.append(i)
2231
2232 return factors
2233
    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        """Generate new-shape arguments for RESHAPE tests.

        For each random permutation a random rank is chosen and a new shape
        is built from factors of the total element count.  Each accepted
        shape is emitted twice: fully defined, and with one dimension
        replaced by -1 (inferred) - unless an ErrorIf case modifies this.

        Returns:
            List of (arg_str, [new_shape]) tuples.
        """
        arg_list = []

        origShape = shapeList[0]

        # Total number of elements must be preserved by any new shape
        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            # Rank from 1 to TOSA_TENSOR_MAX_RANK
            newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                new_shape_inferred = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                # Dimension (1-based here) that will be -1 in the inferred shape
                inferred_dim = testGen.rng.integers(1, newRank + 1)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    if i == inferred_dim:
                        new_shape_inferred.append(-1)
                    else:
                        new_shape_inferred.append(shuffledFactors[0])
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)
                if inferred_dim == newRank:
                    new_shape_inferred.append(-1)
                else:
                    new_shape_inferred.append(remainingElements)

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                if error_name in [
                    ErrorIf.ReshapeOutputSizeNonInteger,
                    ErrorIf.ReshapeOutputSizeMultiInference,
                ]:
                    if newRank < 2:
                        # Need at least two dimensions
                        continue
                    # NOTE: Change inferred_dim starting offset from 1 to 0
                    inferred_dim -= 1
                    # Pick a different dimension to corrupt for the error case
                    extra_dim = inferred_dim + testGen.rng.integers(1, newRank)
                    extra_dim = extra_dim % newRank
                    assert extra_dim != inferred_dim
                    if error_name == ErrorIf.ReshapeOutputSizeNonInteger:
                        # Grow extra_dim until the element count no longer divides
                        elements = 1
                        for i, dim_value in enumerate(new_shape_inferred):
                            if i != inferred_dim and i != extra_dim:
                                elements *= dim_value
                        dim_value = new_shape_inferred[extra_dim]
                        while totalElements % (elements * dim_value) == 0:
                            dim_value += 1
                        new_shape_inferred[extra_dim] = dim_value
                    else:
                        assert error_name == ErrorIf.ReshapeOutputSizeMultiInference
                        # Two -1 dimensions triggers the multi-inference error
                        new_shape_inferred[extra_dim] = -1
                else:
                    arg_list.append(
                        ("perm{}_rank{}_outdefined".format(p, newRank), [newShape])
                    )
                if error_name != ErrorIf.TensorSizeInputOutputMismatch:
                    arg_list.append(
                        (
                            "perm{}_rank{}_outinferred".format(p, newRank),
                            [new_shape_inferred],
                        )
                    )

        return arg_list
2329
2330 @staticmethod
2331 def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
2332 arg_list = []
2333
2334 ifm_shape = shapeList[0]
2335
2336 if error_name == ErrorIf.IndexOutsideBounds:
2337 incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
2338 incorrect_small_index = range(-len(ifm_shape), 0)
2339 permutations = [p for p in itertools.permutations(incorrect_large_index)]
2340 permutations.extend(
2341 [p for p in itertools.permutations(incorrect_small_index)]
2342 )
2343 elif error_name == ErrorIf.IndexUsedTwice:
2344 # Create list with a duplicated index
2345 perm_range = list(range(len(ifm_shape)))
2346 index_choice = testGen.rng.choice(range(len(perm_range)))
2347 perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
2348 permutations = [p for p in itertools.permutations(perm_range)]
2349
2350 else:
2351 # Get all permutations
2352 permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
2353
2354 # Limit to possible permutations from shape dimension or argument setting
2355 limit = min(len(permutations), testGen.args.num_rand_permutations)
2356
2357 # Get random permutation generator that uses all permutations
2358 random_permutations = testGen.rng.permutation(permutations)
2359
2360 # Create list of required amount of permutations
2361 arg_list = [
2362 ("perm{}".format(p), [random_permutations[p].tolist()])
2363 for p in range(limit)
2364 ]
2365 return arg_list
2366
2367 @staticmethod
2368 def agSlice(testGen, opName, shapeList, dtype, error_name=None):
2369 arg_list = []
2370
2371 ifm_shape = shapeList[0]
2372 rank = len(ifm_shape)
2373
2374 for p in range(testGen.args.num_rand_permutations):
2375 start = []
2376 size = []
2377
2378 valid = True
2379
2380 for i in range(rank):
2381 if ifm_shape[i] > 1:
2382 start.append(testGen.randInt(0, ifm_shape[i]))
2383 size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
2384
2385 # Invalid slice size?
2386 if size[i] == 0:
2387 valid = False
2388 else:
2389 start.append(0)
2390 size.append(1)
2391
2392 if valid:
2393 # If ERROR_IF test required then incorrect start, size will be returned
2394 start, size = TosaErrorIfArgGen.eiSliceErrorIf(
2395 testGen, error_name, ifm_shape, start, size
2396 )
2397 arg_list.append(("perm{}".format(p), [start, size]))
2398 return arg_list
2399
2400 @staticmethod
2401 def agTile(testGen, opName, shapeList, dtype, error_name=None):
2402 arg_list = []
2403
2404 ifm_shape = shapeList[0]
2405 rank = len(ifm_shape)
2406
2407 for p in range(testGen.args.num_rand_permutations):
2408
2409 # Pick a few random, but small multiple values
2410 # because otherwise this has a tendency to generate
2411 # enormous tensors
2412 multiples = []
2413 for i in range(rank):
2414 if ifm_shape[i] > 1000:
2415 # Multiple of 1 if ifm_shape dimension is large to reduce
2416 # tensor size
2417 multiples.append(1)
2418 elif max(ifm_shape) > 1000:
2419 multiples.append(2)
2420 else:
2421 multiples.append(testGen.randInt(1, 4))
2422 arg_list.append(("perm{}".format(p), [multiples]))
2423
2424 return arg_list
2425
2426 @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        """Generate argument permutations for RESIZE tests.

        Returns a list of (name, [mode, scale, offset, border, dtype,
        outputDType]) tuples where scale is [scale_y_n, scale_y_d,
        scale_x_n, scale_x_d], offset is [offset_y, offset_x] and border is
        [border_y, border_x].  Parameter sets come from one of several
        randomised strategies and are filtered so that the computed output
        dimensions stay legal - or deliberately illegal when error_name
        requests an ERROR_IF test.
        """
        arg_list = []
        # NOTE: assumes NHWC input, i.e. ifm_shape[1] is height and
        # ifm_shape[2] is width (consistent with the indexing below)
        ifm_shape = shapeList[0]

        def get_aspect_ratio_resize_params():
            """Pick a common aspect-ratio scale with letterbox/pillarbox border."""
            common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
            aspect_ratio = testGen.rng.choice(common_aspect_ratios)
            invert = testGen.rng.choice((False, True))
            letterbox = testGen.rng.choice((False, True))

            # invert swaps which axis gets the larger ratio term
            scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
            scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
            scale_y_d = scale_x_d = 1
            offset_x = offset_y = 0

            if letterbox:
                # Letterboxing: border only on the Y axis
                max_border = scale_y_n
                border_y = testGen.randInt(low=0, high=max_border)
                border_x = 0
            else:
                # Pillarboxing
                border_y = 0
                max_border = scale_x_n
                border_x = testGen.randInt(low=0, high=max_border)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)

            return scale, offset, border

        def get_upscale_downscale_params():
            """Pick power-of-two upscale params or integer-output downscale params."""
            valid_params = False
            while not valid_params:
                upscale = testGen.rng.choice((False, True))

                # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
                origin_sampling = testGen.rng.choice((False, True))

                if upscale:
                    shift = testGen.randInt(low=1, high=4)
                    scale_x_d = scale_y_d = 1
                    scale_x_n = scale_y_n = (
                        1 << shift if origin_sampling else 2 << shift
                    )
                    border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                    offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
                else:
                    scale_x_n = 1
                    scale_y_n = 1

                    # Return list of valid scale_*_d values (max value 4) given input dim shape
                    def get_valid_denom(ifm_dim):
                        # ifm_dim % x == 1 <=> (ifm_dim - 1) % x == 0, i.e. the
                        # denominator divides (ifm_dim - 1) exactly (for x > 1)
                        return [x for x in range(1, 5) if ifm_dim % x == 1]

                    # Generate list of valid downscale values and choose one randomly
                    valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                    valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                    if not valid_scale_y_ds and not valid_scale_x_ds:
                        # Bad parameters, skip
                        continue

                    if not valid_scale_y_ds:
                        scale_y_d = 1
                    else:
                        scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                    if not valid_scale_x_ds:
                        scale_x_d = 1
                    else:
                        scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                    border_x = border_y = 0
                    offset_y = testGen.randInt(0, 16 * scale_y_n)
                    offset_x = testGen.randInt(0, 16 * scale_x_n)
                valid_params = True

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_rand_params():
            """Pick fully random scale/offset/border, capped to the 8K-level max scale."""

            def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
                # Grow the denominator just enough to bring n/d under max_scale
                scale = scale_n / scale_d
                if scale > max_scale:
                    factor = scale / max_scale
                    new_scale_d = math.ceil(scale_d * factor)
                    assert scale_n / new_scale_d <= max_scale
                    scale_d = new_scale_d
                return scale_d

            # Scale
            scale_y_n = testGen.randInt(low=1, high=(1 << 11))
            scale_x_n = testGen.randInt(low=1, high=(1 << 11))

            scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

            scale_y_d = fix_scale_to_max_scale(
                scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_x_d = fix_scale_to_max_scale(
                scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )

            # Offsets and border within the scale
            offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
            border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_level_8k_params():
            """Pick params that exercise the TOSA 8K-level maximum scaling limits."""
            # Create 64x scale - 64/1 to 2048/32
            # NOTE(review): high is a float here due to true division - assumes
            # randInt accepts a float upper bound; confirm against randInt impl
            scale_d = testGen.randInt(
                low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
            # Create half to fifth scaling
            scale_d_alt = testGen.randInt(low=2, high=6)
            scale_n_alt = 1
            # Randomly assign the extreme scale to the Y or X axis
            switch = testGen.rng.choice((False, True))
            if switch:
                scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
            else:
                scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)

            # Boundary offset/border values for the chosen scale
            offset_y = testGen.rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
            offset_x = testGen.rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
            offset = (offset_y, offset_x)
            border_y = testGen.rng.choice((-16 * scale[0], 0, scale[0] - 1))
            border_x = testGen.rng.choice((-16 * scale[2], 0, scale[2] - 1))
            border = (border_y, border_x)
            return scale, offset, border

        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.BF16:
                outputDTypeList = [DType.BF16]
            elif dtype == DType.FP32:
                outputDTypeList = [DType.FP32]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                perm = 0
                # perm counts accepted (or intentionally skipped) candidates;
                # rejected candidates usually do not advance it, so the loop
                # keeps drawing until enough valid permutations are found
                while perm < testGen.args.num_rand_permutations:
                    # Random choice of type of params we are testing
                    if not testGen.args.level8k:
                        _rnd_param_fn = testGen.rng.choice(
                            (
                                get_rand_params,
                                get_upscale_downscale_params,
                                get_aspect_ratio_resize_params,
                            )
                        )
                        scale, offset, border = _rnd_param_fn()
                    else:
                        scale, offset, border = get_level_8k_params()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        # Look for non-integer test
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            if perm > 0:
                                perm += 1
                            continue
                    else:
                        # Alter the scaling factors to make the output integer
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1
                        # Make sure we are still within max scaling
                        if (
                            scale_y_n / scale_y_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
                            scale_x_n / scale_x_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
                            # Skip the test as it is using too large a scaling factor
                            if perm > 0:
                                perm += 1
                            continue

                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1

                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        if not testGen.args.level8k or perm > 0:
                            perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= gtu.MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= gtu.MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]

                    # Common for all data types
                    if error_name is not None:
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        [
                            mode,
                            scale,
                            offset,
                            border,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)
        return arg_list
2745
2746 @staticmethod
    def agTable(testGen, opName, shapeList, dtype, error_name=None):
        """Generate a random lookup-table argument for TABLE tests.

        INT8 tables have 256 entries; INT16 tables have 513 entries and
        additionally have their entry-to-entry slopes clamped to the signed
        16-bit range.
        """
        arg_list = []

        if dtype == DType.INT8:
            # 256 random int8 entries
            table = np.int32(
                testGen.rng.integers(low=-128, high=128, size=[256])
            ).tolist()
        else:  # INT16
            # 513 random int16 entries
            table = np.int32(
                testGen.rng.integers(low=-32768, high=32768, size=[513])
            ).tolist()
            # Make sure all slopes are within REQUIRE min/max 16-bit int
            for idx in range(len(table) - 1):
                slope = table[idx + 1] - table[idx]
                # Alter the next table entry to force the slope to be ok
                if slope > 32767:
                    # Pull the next entry down so the slope becomes exactly 32767
                    table[idx + 1] -= slope - 32767
                if slope < -32768:
                    # Push the next entry up so the slope becomes exactly -32768
                    table[idx + 1] -= slope + 32768
                slope = table[idx + 1] - table[idx]
                assert slope <= 32767 and slope >= -32768
        arg_list.append(
            (
                "",
                [table],
            )
        )
        return arg_list
2775
2776 def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
2777 # CondIf generates the condition values here.
2778 # Convert to tensors in the build function, along with the
2779 # then and else blocks
2780 arg_list = []
2781
2782 for c in [False, True]:
2783 arg_list.append(("cond{}".format(int(c)), [c]))
2784
2785 return arg_list
2786
2787 def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
2788 # While loop: 0 iterations, 1, more than 1
2789 arg_list = []
2790
2791 for iter in [0, 1, 4]:
2792 arg_list.append(("iter{}".format(iter), [iter]))
2793
2794 return arg_list