# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
import math
import warnings

import generator.tosa_utils as gtu
import numpy as np
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode

# DTypeNames, DType, Op and ResizeMode are convenience variables to the
# flatc-generated types that should be enums, but aren't


class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
24 """
25
26 def __init__(self):
27 pass
28
29 @staticmethod
Eric Kunzeb5fabec2022-06-07 05:20:44 +000030 def getZeroPoint(testGen, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010031
32 if dtype == DType.INT8:
Jeremy Johnson00423432022-09-12 17:27:37 +010033 if testGen.args.zeropoint is not None:
34 return min(127, max(-128, testGen.args.zeropoint))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010035 return testGen.randInt(-128, 128)
36 elif dtype == DType.UINT8:
Jeremy Johnson00423432022-09-12 17:27:37 +010037 if testGen.args.zeropoint is not None:
38 return min(255, max(0, testGen.args.zeropoint))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010039 return testGen.randInt(0, 256)
40 elif error_name in [
41 ErrorIf.InputZeroPointNotZero,
42 ErrorIf.WeightZeroPointNotZero,
43 ErrorIf.OutputZeroPointNotZero,
44 ]:
45 zero_point = testGen.randInt(-128, 128)
46 if zero_point == 0:
47 zero_point = 1
48 return zero_point
49 return 0
50
51 @staticmethod
52 def qgUnary(testGen, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010053 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000054 qinfo = [
55 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
56 TosaQuantGen.getZeroPoint(testGen, dtype),
57 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010058 elif error_name == ErrorIf.OutputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000059 qinfo = [
60 TosaQuantGen.getZeroPoint(testGen, dtype),
61 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
62 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010063 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000064 qinfo = [
65 TosaQuantGen.getZeroPoint(testGen, dtype),
66 TosaQuantGen.getZeroPoint(testGen, dtype),
67 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010068 return qinfo
69
70 @staticmethod
71 def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010072 if isinstance(dtype_or_dtypeList, list):
73 # a list of [input, weights, accumulator] dtypes
74 dtypeList = dtype_or_dtypeList
75 else:
76 # an int, [input, weights, accumulator] dtypes are the same
77 dtypeList = [dtype_or_dtypeList] * 3
78
79 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000080 qinfo = [
81 TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
82 TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
83 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010084 elif error_name == ErrorIf.WeightZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000085 qinfo = [
86 TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
87 TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
88 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010089 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000090 qinfo = [
91 TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
92 TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
93 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010094 return qinfo
95
96 @staticmethod
97 def qgMatmul(testGen, op, dtype, error_name=None):
Jeremy Johnson9a66abb2022-04-07 11:29:20 +010098 if error_name == ErrorIf.InputZeroPointNotZero:
Eric Kunzeb5fabec2022-06-07 05:20:44 +000099 qinfo = [
100 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
101 TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
102 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100103 else:
Eric Kunzeb5fabec2022-06-07 05:20:44 +0000104 qinfo = [
105 TosaQuantGen.getZeroPoint(testGen, dtype),
106 TosaQuantGen.getZeroPoint(testGen, dtype),
107 ]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +0100108 return qinfo
109
110 @staticmethod
111 def computeMultiplierAndShift(scaleFp, scale32):
112 # Derived from computeMultiplierAndShiftTosaScale32
113 # Provide a floating-point scaling factor and the scale32 parameter
114 # to compute the multiplier and shift
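        # Illustrative worked example (values chosen for clarity, not taken
        # from the spec): scaleFp=0.5, scale32=True gives scaleBits=31;
        # math.frexp(0.5) returns (0.5, 0), so
        # multiplier = round(0.5 * 2**31) = 1073741824 and
        # shift = -0 + 31 = 31, i.e. 1073741824 * 2**-31 == 0.5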

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
        #     scaleFp, scaleBits, m, multiplier, shift))

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift


class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = min(values_in_shape[0], testGen.args.max_batch_size)

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        fuzz_idx = testGen.randInt(0, rank)
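        # Illustrative example (hypothetical values): with shape [3, 4, 5],
        # fuzz_idx=1 and bcast_idx=1, the resulting shape_list would be
        # [[3, 4, 5], [3, 1, 5]] - input 1 is broadcast along axis 1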

        for i in range(pl + const):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension()

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFFT2d(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))
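        # e.g. a random dimension of 30 becomes 2 ** int(log2(30)) = 2**4 = 16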

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
        if error_name == ErrorIf.FFTInputShapeMismatch:
            modify_shape = testGen.rng.choice([0, 1])
            # Only modify kernel (H, W)
            modify_dim = testGen.rng.choice([1, 2])
            ifm_shapes[modify_shape][modify_dim] *= 2

        return [ifm_shapes[0], ifm_shapes[1]]

    @staticmethod
    def tgRFFT2d(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 1 and const == 0

        # IFM dimensions are NHW
        ifm_shape = testGen.makeShape(rank)

        # Select nearest lower power of two from input height and width
        ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
        ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)

        # Generate an invalid kernel that is not a power of two
        if error_name == ErrorIf.KernelNotPowerOfTwo:
            # We must increment by 2 if current size is 1
            inc_h = 2 if ifm_shape[1] == 1 else 1
            inc_w = 2 if ifm_shape[2] == 1 else 1
            inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
            selected_inc = testGen.rng.choice(inc_choices)
            ifm_shape[1] += selected_inc[0]
            ifm_shape[2] += selected_inc[1]

        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        return [ifm_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
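        # Illustrative example (hypothetical values): 4 input shapes of axis
        # length 8 become axis lengths [8, 4, 2, 2] - the first input keeps
        # the full length and the others halve the remainder, summing to 8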
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList


class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each tensor in each test."""

    def __init__(self):
        pass

    class TVGInfo:
        """Enhanced tensor values information including data gen dict."""

        def __init__(self, tensorList, dataGenDict):
            self.tensorList = tensorList
            self.dataGenDict = dataGenDict

    @staticmethod
    def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        pCount, cCount = op["operands"]

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        return tens

    @staticmethod
    def tvgLazyGenDefault(
        testGen, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]

        if (
            error_name is not None
            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
            or opName in ("avg_pool2d",)
        ):
            # Fall back to original path when dealing with unsupported types

            # First turn off lazy data gen so we always produce data
            lazy_data_gen = testGen.args.lazy_data_gen
            testGen.args.lazy_data_gen = False

            tens_ser_list = TosaTensorValuesGen.tvgDefault(
                testGen,
                testGen.TOSA_OP_LIST[opName],
                dtypeList,
                shapeList,
                [],
                error_name,
            )
            # Restore lazy data gen setting
            testGen.args.lazy_data_gen = lazy_data_gen
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
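        # Illustrative sketch of the per-tensor meta-data built below (all
        # values hypothetical):
        #   {"generator": "PSEUDO_RANDOM", "data_type": "fp32",
        #    "shape": [1, 8, 8], "input_pos": 0, "op": "MATMUL",
        #    "input_type": "VARIABLE", "pseudo_random_info": {...}}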
        dg_type = argsDict["dg_type"]
        dg_tens_meta = {}
        tens_ser_list = []
        for idx, shape in enumerate(shapeList):

            tens_meta = {}
            tens_meta["generator"] = gtu.DataGenType(dg_type).name
            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = opName.upper()

            if idx < pCount:
                tens_meta["input_type"] = "VARIABLE"
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], None)
            else:
                tens_meta["input_type"] = "CONSTANT"
                tens = testGen.ser.addConst(shape, dtypeList[idx], None)
            tens_ser_list.append(tens)

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                # TODO - generate seed for this generator based on test
                info["rng_seed"] = 42
                info["range"] = [
                    str(v)
                    for v in testGen.getDTypeRange(dtypeList[idx], high_inclusive=True)
                ]
                tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = argsDict["ks"]
                for key in gtu.DG_DOT_PRODUCT_OPTIONAL_INFO:
                    if key in argsDict:
                        if key.endswith("_type"):
                            info[key] = gtu.DTYPE_ATTRIBUTES[argsDict[key]]["json"]
                        else:
                            info[key] = argsDict[key]
                tens_meta["dot_product_info"] = info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"
            dg_tens_meta[tens.name] = tens_meta

        tens_data = {
            "version": "0.1",
            "tensors": dg_tens_meta,
        }
        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)

    @staticmethod
    def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32 and error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
719 ), "Op.NEGATE must have 1 placeholders, 0 consts"
            # Must create tensors with values within accumulator (int32) negatable
            # range
            max_val = (1 << 31) - 1
            min_val = -max_val
            arr = np.int32(
                testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Make sure the operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
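            # Illustrative example (hypothetical values): if a = 2**31 - 1
            # and b = 5, the int64 ADD result is 2**31 + 4, so sat_max_arr
            # is 5 and b is clipped down by 5 to 0, keeping a + b in range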
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            placeholders = []
            add = op["op"] == Op.ADD
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgCondIfWhileLoop(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        if dtypeList[0] in (
            DType.INT32,
            DType.INT16,
            DType.INT8,
        ):
            # Limit input tensors with cond_if_binary or while_loop to stop
            # saturation of add/sub ops with int32 and keep all logical shift
            # values between 0 and 31 for int16 or int8
            pCount, cCount = op["operands"]
            pRemain = pCount
            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[0] == DType.INT32:
                    arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
                else:
                    arr = np.int32(
                        testGen.rng.integers(low=0, high=32, size=shapeList[idx])
                    )
                if pRemain > 0:
                    placeholders.append(
                        testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
                    )
                    pRemain -= 1
                else:
                    placeholders.append(
                        testGen.ser.addConst(shape, dtypeList[idx], arr)
                    )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgArithmeticRightShift(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        pCount, cCount = op["operands"]
        # Force value of operand[1] to be within [0, num_bits]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

        placeholders = []
        for idx, shape in enumerate(shapeList[:]):
            if idx == 1:
                if dtypeList[idx] == DType.INT8:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                elif dtypeList[idx] == DType.INT16:
                    arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
                elif dtypeList[idx] == DType.INT32:
                    arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
                elif error_name == ErrorIf.WrongInputType:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                else:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
            else:
                arr = testGen.getRandTensor(shape, dtypeList[idx])
            placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))

        return placeholders

    @staticmethod
    def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        # Set datatype of condition tensor to boolean
        dtypeList[0] = DType.BOOL

        return TosaTensorValuesGen.tvgDefault(
            testGen, op, dtypeList, shapeList, testArgs, error_name
        )

    @staticmethod
    def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
            while True:
                dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
                    continue

                break

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            tens = []
            if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32):
                tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure the multiply results are within int32 range
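                # Illustrative example (hypothetical values): with shift=2
                # the scaled result is ((a * b) + 2) >> 2; if any result
                # overflows int32, both operands are halved and re-checked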
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                elif error_name == ErrorIf.WrongInputType:
                    num_bits = 8
                else:
                    raise Exception("OpMul: invalid input dtype")

                for idx, shape in enumerate(shapeList[:]):
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                    a_arr = np.int32(
                        testGen.rng.integers(low=low, high=high, size=shapeList[0])
                    )
                    b_arr = np.int32(
                        testGen.rng.integers(low=low, high=high, size=shapeList[1])
                    )

                i = 0
                while True:

                    a_arr_64 = a_arr.astype(np.int64)
                    b_arr_64 = b_arr.astype(np.int64)

                    if shift > 0:
                        rounding = 1 << (shift - 1)
                        result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                    else:
                        result_arr = a_arr_64 * b_arr_64

                    if (result_arr > -(2**31)).all() and (
                        result_arr <= ((2**31) - 1)
                    ).all():
                        break

                    i = i + 1
                    a_arr = a_arr // 2
                    b_arr = b_arr // 2

                placeholders.append(
                    testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)

            return tens
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        count = len(shapeList) - testGen.args.num_const_inputs_concat
        if count < 1:
            count = 1
        if testGen.args.num_const_inputs_concat == 0:
            count = len(shapeList)

        # Ensure axis is an int
        testArgs[0] = int(testArgs[0])

        shapeList = TosaTensorGen.tgConcatConstInput(
            testGen, shapeList, testArgs[0], error_name
        )

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
        )
        tens.extend(testGen.buildConstTensors(shapeList[count:], dtypeList[count:]))

        return tens

    @staticmethod
    def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        pCount, cCount = op["operands"]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
        values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
        shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
        placeholders = []
        placeholders.append(
            testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
        )
        placeholders.append(
            testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
        )

        return placeholders

    @staticmethod
    def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        testGen.rng.integers(
                            0, max(shapeList[0][axis], shapeList[1][axis])
                        )
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
1083 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
            # Limit values so that the sum cannot exceed the range of an int32 during
            # summation of any axis
            range_val = int((1 << 31) / max(shapeList[0]))
            values_arr = np.int32(
                testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )


class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for
    operators that take attributes or other parameters.

    The return value is a list of (descriptive_name, [arglist]) tuples where
    the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function.
    """

    def __init__(self):
        pass

    @staticmethod
    def _add_data_generators(testGen, opName, dtype, arg_list, error_name):
        """Add extra tests for each type of data generator for this op."""
        if (
            error_name is None
            and "data_gen" in testGen.TOSA_OP_LIST[opName]
            and gtu.dtypeIsSupportedByCompliance(dtype)
        ):
            if dtype in [DType.FP16, DType.FP32, DType.BF16]:
                dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
            else:
                dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
        else:
            # Error test or no data generator types listed - assume random
            dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)

        # Expand arg list with other data generator types
        new_arg_list = []
        for dg_type in dataGenTypesList:
            for arg_str, args_dict in arg_list:
                args_dict["dg_type"] = dg_type
                if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                    # Default test
                    new_arg_list.append((arg_str, args_dict))

                elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                    # Extra tests for each dot product test set
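                    # e.g. a hypothetical arg string "accf32" expands into
                    # "accf32_s0", "accf32_s1", ... - one per test set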
                    dot_products = args_dict["dot_products"]
                    if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
                        print(
                            f"Skipping {opName} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
                        )
                        continue
                    # KS is required by all dot product generators
                    assert "ks" in args_dict

                    for s in testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS:
                        new_arg_str = f"{arg_str}_s{s}"
                        new_args_dict = args_dict.copy()
                        new_args_dict["s"] = s
                        new_arg_list.append((new_arg_str, new_args_dict))

        return new_arg_list

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype, error_name=None):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype, error_name=None):
        """Build the axis argument for operators that take a single axis"""
        axes = []
        shape = shapeList[0]

        if error_name == ErrorIf.AxisSmallerZero:
            small_axis = testGen.rng.integers(-5, 0)
            axes.append(("axis{}".format(small_axis), [small_axis]))
        elif error_name == ErrorIf.AxisLargerRank:
            large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
            axes.append(("axis{}".format(large_axis), [large_axis]))
        else:
            for a in range(0, len(shape)):
                axes.append(("axis{}".format(a), [a]))

        return axes

    @staticmethod
    def _calculate_sparsity(num_tests, sparsity_factor):
        sparsity = num_tests // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
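        # Illustrative example: num_tests=10000, sparsity_factor=120 gives
        # 10000 // 120 + 1 = 84, then stepping past multiples of 2, 3 and 5
        # produces a final sparsity of 89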
        return sparsity

    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        # Shape: Batches, (Depth), Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: (OFM channels), (KD), KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Check the rank
        conv3d = opName.startswith("conv3d")
        rank = 5 if conv3d else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits channels
        k_rank = rank - 2
        k_pos = 0 if opName.startswith("depthwise") else 1
        k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            if error_name == ErrorIf.PadSmallerZero:
                p_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                # Stride must be greater than 1 to force non-integer error
                startStride = (
                    1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
                )
                s_vals = [
                    x for x in range(startStride, testGen.args.max_conv_stride + 1)
                ]
            strides = {x for x in itertools.product(*([s_vals] * k_rank))}
            if error_name == ErrorIf.DilationSmallerOne:
                d_vals = [testGen.rng.choice(range(-5, 1))]
            else:
                d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
            dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[0, bigPadding]] * (k_rank * 2))
                            )
                        }
                    )
                bigStride = 8
                strides.update(
                    {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
                )
                bigDilation = 7
                dilations.update(
                    {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
                )
            max_dim_size = None

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(paddings) * len(strides) * len(dilations), sparsity_factor
            )
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            dilation_shape = [1] * k_rank
            pad_shape = [0] * k_rank * 2
            if conv3d:
                # Small stride apart from for big kernel (see below) to keep
                # tensor size/calculation small
                stride_shape = [1] * k_rank
                for idx in range(k_rank):
                    pad_offset = idx * 2
                    if k_shape[idx] == bigKernel:
                        # Padding shape needs to account for tensor shape
                        pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                        pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                        # Big stride to reduce output size
                        stride_shape[idx] = bigKernel
                    else:
                        # Account for kernel size
                        pad_shape[pad_offset] = k_shape[idx] - 1
            else:
                # Always have a large stride with extra padding and dilation to keep
                # tensor calculation reasonable
                stride_shape = [bigKernel] * k_rank
                for idx in range(k_rank):
                    # Dilation shape must account for kernel size
                    dilation_shape[idx] = bigKernel // k_shape[idx]
                    # Padding shape needs to accommodate tensor/kernel & dilation
                    pad_offset = idx * 2
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

            strides = {tuple(stride_shape)}
            dilations = {tuple(dilation_shape)}
            paddings = {tuple(pad_shape)}
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a positive
                        # sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
                        and (
                            k_rank < 3
                            or (
                                (ifm_shape[3] - 1 + p[4] + p[5])
                                > d[2] * (k_shape[2] - 1)
                            )
                        )
                    ):
                        remainders = []
                        outputs = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            partial = (
                                ifm_shape[index + 1]
                                - 1
                                + p[pad_offset]
                                + p[pad_offset + 1]
                                - (k_shape[index] - 1) * d[index]
                            )
                            remainders.append(partial % s[index])
                            outputs.append((partial // s[index]) + 1)
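                        # e.g. (hypothetical values) ifm 16, pad 1+1,
                        # kernel 3, dilation 1: partial = 16 - 1 + 2 - 2 = 15;
                        # stride 2 leaves remainder 1 (non-integer output)
                        # while stride 3 is exact: 15 // 3 + 1 = 6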

                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            if (
                                max_dim_size is not None
                                and max(outputs) >= max_dim_size
                            ):
                                # Test will consume too much memory - skip it
                                continue

                            # Support for larger values than 9 needs different delimiter
                            delim = "" if max(s + p + d) <= 9 else "x"
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        delim.join([str(x) for x in s]),
                                        delim.join([str(x) for x in p]),
                                        delim.join([str(x) for x in d]),
                                    ),
                                    [accum_dtype, s, p, d],
                                )
                            )
                    n += 1

        return arg_list

    @staticmethod
    def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):

        assert isinstance(dtypes, list) or isinstance(
            dtypes, tuple
        ), f"{dtypes} unexpected"
        input_dtype = dtypes[0]

        if error_name == ErrorIf.WrongOutputType:
            accum_dtype = gtu.get_wrong_output_type(opName, testGen.rng, input_dtype)
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            accum_dtype = DType.INT32
        else:
            accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])]

    @staticmethod
    def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
        # Get valid accumulate type(s)
        if dtype == DType.INT8:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.INT16:
            accum_dtypes = [DType.INT48]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FP32]
        elif dtype == DType.BF16:
            accum_dtypes = [DType.FP32]
        elif dtype == DType.FP32:
            accum_dtypes = [DType.FP32]
        elif error_name is None:
            assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"

        if error_name == ErrorIf.WrongOutputType:
            # Get incorrect output dtype for ErrorIf case
            accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            accum_dtypes = [DType.INT32]

        # Set up compliance info
        args_dict = {
            "ks": int(shapeList[0][2]),  # Set KS = C, from input A (N,H,C)
            # Set dot_products = N*H*W
            "dot_products": gtu.product(
                (shapeList[0][0], shapeList[0][1], shapeList[1][2])
            ),
        }
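        # e.g. (hypothetical shapes) A of (2, 3, 4) and B of (2, 4, 5) give
        # ks = 4 and dot_products = 2 * 3 * 5 = 30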

        # Create arg tuple of string and dict
        arg_list = []
        for a in accum_dtypes:
            d = args_dict.copy()
            d["acc_type"] = a
            arg_list.append((f"acc{testGen.typeStr(a)}", d))

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Must be rank 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == 4
            assert len(filter_shape) == 4

        k_shape = tuple(filter_shape[1:3])

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1
            if error_name == ErrorIf.PadLargerEqualKernel:
                max_filter_size = -max(k_shape[0], k_shape[1])
                p_vals = [
                    testGen.rng.choice(range(max_filter_size - 10, max_filter_size))
                ]
            else:
                p_vals = [
                    x
                    for x in range(
                        smallest_padding_size, testGen.args.max_conv_padding + 1
                    )
                ]
            paddings = {x for x in itertools.product(*([p_vals] * 4))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive the output shape
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
            strides = {x for x in itertools.product(*([s_vals] * 2))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[smallest_padding_size, bigPadding]] * 4)
                            )
                        }
                    )
                bigStride = 8
                strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 10
            sparsity = len(paddings) * len(strides) // sparsity_factor + 1
            # If there are only a small number of tests, just select them all
            if sparsity < 13:
                sparsity = 1
            # To get a variety of parameter combinations sparsity should not be a
            # multiple of 2, 3 or 5
            while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
                sparsity += 1
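            # Illustrative example (counts assumed): with 81 paddings and 2
            # strides, sparsity = 162 // 10 + 1 = 17, which is not a multiple
            # of 2, 3 or 5, so every 17th (stride, padding) combination below
            # is kept while still cycling through varied parameter values.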
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            pad_shape = [0] * (len(k_shape) * 2)
            stride_shape = [1] * len(k_shape)
            # The point at which input dimension combined with the stride will
            # create large output sizes!
            LARGE_SIZE = 2
            for idx in range(len(k_shape)):
                pad_offset = idx * 2
                if k_shape[idx] == bigKernel:
                    # Set large stride
                    stride_shape[idx] = bigKernel
                    # Use negative output padding to reduce shape size
                    pad_shape[pad_offset] = -(bigPadding - 1)
                    if ifm_shape[idx + 1] > LARGE_SIZE:
                        pad_shape[pad_offset + 1] = -(bigPadding - 1)
                else:
                    # The other dimension should be the bigKernel
                    alt_idx = 1 - idx
                    if (
                        k_shape[alt_idx] == bigKernel
                        and ifm_shape[alt_idx + 1] < LARGE_SIZE
                    ):
                        # As the input is small, the large stride won't
                        # affect the output so we can add some padding
                        pad_shape[pad_offset + 1] = bigPadding

            strides = {tuple(stride_shape)}
            paddings = {tuple(pad_shape)}

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                if n % sparsity == 0:
                    # Determine the output shape
                    oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
                    ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]
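                    # Worked example (values assumed): a 1x4x4xIC input with
                    # stride (2,2), padding (0,0,0,0) and a 3x3 kernel gives
                    # oh = (4-1)*2 + 0 + 0 + 3 = 9, i.e. an output shape of
                    # [1, 9, 9, OC].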

                    # Support for larger values than 9 needs different delimiter
                    delim = "" if max(s + p) <= 9 else "x"
                    arg_list.append(
                        (
                            "acc{}_st{}_pad{}_os{}".format(
                                testGen.typeStr(accum_dtype),
                                delim.join([str(x) for x in s]),
                                delim.join([str(x) for x in p]),
                                "x".join([str(x) for x in os]),
                            ),
                            [accum_dtype, s, p, os],
                        )
                    )
                n += 1

        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))
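        # Illustrative example (rank assumed): for a rank-2 input with
        # pad_values [0, 1], axis_pad_values holds 4 (before, after) pairs
        # and shape_pad_values enumerates all 4**2 = 16 padding combinations.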

        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            return []

        list_shape_pad_values = list(shape_pad_values)
        # If we are producing tests for rank 6 or greater use sparsity
        if len(list_shape_pad_values) > 1024:
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(list_shape_pad_values), sparsity_factor
            )
        else:
            sparsity = 1

        # Build arg list
        arg_list = []
        for n, paddings in enumerate(list_shape_pad_values):
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while still ensuring we test
                # negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                # Invalid if no negative padding remains in any dimension
                if all(p > -1 for pads in paddings for p in pads):
                    args_valid = False
            if args_valid and n % sparsity == 0:
                name = "pad"
                for r in range(rank):
                    before, after = paddings[r]
                    name = f"{name}{before}{after}"
                args_dict = {
                    "pad": np.array(paddings),
                    "pad_const_int": pad_const_int,
                    "pad_const_fp": pad_const_fp,
                }
                arg_list.append((name, args_dict))
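                # e.g. paddings [(0, 1), (1, 0)] produces the test name
                # "pad0110" - one before/after digit pair per dimension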

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list

    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        shape = shapeList[0]
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        test_level8k = testGen.args.level8k and error_name is None

        startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
        startKernel = 2
        startPad = 0
        if not test_level8k:
            # Generate comprehensive argument lists
            p_vals = [x for x in range(startPad, testGen.args.max_pooling_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * 4))}
            # Stride must be greater than 1 to force non-integer error
            s_vals = [
                x for x in range(startStride, testGen.args.max_pooling_stride + 1)
            ]
            strides = {x for x in itertools.product(*([s_vals] * 2))}
            k_vals = [
                x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
            ]
            kernels = {x for x in itertools.product(*([k_vals] * 2))}
            max_dim_size = None
        else:
            # Only test 8k levels
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            strides = {(1, bigStride), (bigStride, 4)}
            kernels = {(1, bigKernel), (bigKernel, 3)}
            paddings = set()
            for s in sorted(list(strides)):
                for k in sorted(list(kernels)):
                    padding = []
                    for idx in range(len(k)):
                        total_padding = s[idx] - shape[idx + 1] + k[idx]
                        while total_padding < 0:
                            # Must meet: shape + padding > kernel
                            total_padding += s[idx]
                        if total_padding < k[idx]:
                            padding.extend([0, total_padding])
                        else:
                            # Note this may produce padding >= k[idx] which is not
                            # allowed - but will be ignored in the creation loop below
                            padding.extend([k[idx] - 1, total_padding - (k[idx] - 1)])
                    paddings.add(tuple(padding))
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

        if opName == "max_pool2d":
            accum_dtypes = [None]  # max_pool has no accumulate dtype
        elif dtype == DType.INT8 or dtype == DType.INT16:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FP32]
        elif dtype == DType.BF16 or dtype == DType.FP32:
            accum_dtypes = [DType.FP32]
        elif error_name is None:
            assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
        else:
            # Set to something for the ErrorIf case which has
            # incorrect input data-type
            accum_dtypes = [DType.INT32]

        if not test_level8k:
            if testGen.args.oversize:
                # add some oversize argument values
                bigStride = 7
                bigKernel = 9
                strides.update(
                    {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
                )
                kernels.update(
                    {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
                )
                if max(shape) < 64:
                    # padding must be less than the kernel size
                    bigPadding = bigKernel - 1
                    paddings.update(
                        {x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
                    )

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 500
            sparsity = (
                len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
            )
        else:
            # We have already limited test output combinations for 8k tests
            sparsity = 1

        arg_str = (
            "acc{}_st{}_kern{}_pad{}"
            if accum_dtypes[0] is not None
            else "st{}_kern{}_pad{}"
        )

        def get_arg_list_element(accum, stride, pad, kern, dot_products=0):
            # Return tuple containing the formatted argument string and
            # the corresponding argument values in a dictionary

            # Support for larger values than 9 needs different delimiter
            delim = "" if max(stride + kern + pad) <= 9 else "x"
            arg_str_elems = [
                delim.join([str(x) for x in stride]),
                delim.join([str(x) for x in kern]),
                delim.join([str(x) for x in pad]),
            ]
            args_dict = {
                "stride": stride,
                "pad": pad,
                "kernel": kern,
                "dot_products": dot_products,  # Ignored for error tests
                "ks": gtu.product(kern),  # avg_pool2d: KS = KX*KY
            }

            if accum is not None:
                arg_str_elems.insert(0, testGen.typeStr(accum))
                args_dict["acc_type"] = accum
            return (arg_str.format(*arg_str_elems), args_dict)

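        # Example of the helper output (values assumed): accum=FP32,
        # stride=(1,1), kernel=(2,2), pad=(0,0,0,0) returns
        # ("accf32_st11_kern22_pad0000", {..., "ks": 4, "acc_type": DType.FP32})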
        n = 0
        for a in accum_dtypes:
            for s in sorted(list(strides)):
                for p in sorted(list(paddings)):
                    for k in sorted(list(kernels)):
                        if error_name in [
                            ErrorIf.StrideSmallerOne,
                            ErrorIf.KernelSmallerOne,
                            ErrorIf.PadSmallerZero,
                            ErrorIf.PadLargerEqualKernel,
                        ]:
                            sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
                                testGen, error_name, s, p, k
                            )
                            if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                                arg_list.append(
                                    get_arg_list_element(a, sNew, pNew, kNew)
                                )
                        elif (
                            n % sparsity == 0
                            # padding must not exceed the kernel size
                            and p[0] < k[0]
                            and p[1] < k[0]
                            and p[2] < k[1]
                            and p[3] < k[1]
                            # the padded shape must exceed the kernel size
                            and (shape[1] + p[0] + p[1]) > k[0]
                            and (shape[2] + p[2] + p[3]) > k[1]
                        ):
                            partial_h = shape[1] + p[0] + p[1] - k[0]
                            partial_w = shape[2] + p[2] + p[3] - k[1]
                            remainder_h = partial_h % s[0]
                            remainder_w = partial_w % s[1]
                            output_h = partial_h // s[0] + 1
                            output_w = partial_w // s[1] + 1
                            # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
                            if (
                                # the parameters must produce integer exact output
                                error_name != ErrorIf.PoolingOutputShapeNonInteger
                                and remainder_h == 0
                                and remainder_w == 0
                            ) or (
                                error_name == ErrorIf.PoolingOutputShapeNonInteger
                                and (remainder_h != 0 or remainder_w != 0)
                            ):
                                if (
                                    max_dim_size is not None
                                    and max(output_h, output_w) > max_dim_size
                                ):
                                    # Test will consume too much memory - skip it
                                    continue
                                # Dot products = N*OH*OW*C
                                dp = gtu.product(
                                    (shape[0], output_h, output_w, shape[3])
                                )
                                arg_list.append(get_arg_list_element(a, s, p, k, dp))
                        n += 1

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list

    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        if error_name == ErrorIf.WrongOutputType:
            dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
        elif inDtype == DType.INT8:
            dtypeList = [
                DType.BOOL,
                DType.INT16,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT16:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT32:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT16,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
        elif inDtype == DType.FP16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
        elif inDtype == DType.BF16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
        elif inDtype == DType.FP32:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output type for incorrect input type
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(("out{}".format(testGen.typeStr(dtype)), [dtype]))
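        # e.g. an INT8 input produces names such as "outi16", "outi32" and
        # "outf32" (exact names depend on typeStr naming)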

        return arg_list

    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:

                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition. Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition. ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [outDtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

    @staticmethod
    def agMul(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        if dtype == DType.INT32:
            for p in range(testGen.args.num_rand_permutations):

                shift = testGen.randInt(0, 32)
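                # The INT32 product is right-shifted by `shift`; values
                # 0-31 are sampled here as randInt's high bound is exclusive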

                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("perm0_shift0", [0]))

        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    @staticmethod
    def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        arg_list.append(("inverseTrue", [True]))
        arg_list.append(("inverseFalse", [False]))

        return arg_list

    # Helper function for reshape. Gets some factors of a larger number.
    @staticmethod
    def getFactors(val, start=1):
        factors = []

        for i in range(start, int(np.sqrt(val)) + 1):
            if (val % i) == 0:
                factors.append(i)
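        # Note only factors up to sqrt(val) are collected,
        # e.g. getFactors(24) returns [1, 2, 3, 4]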

        return factors

    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast. Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            # Rank from 1 to TOSA_TENSOR_MAX_RANK
            newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks while loop if it continues on for too long
            escape_counter = 0
            while found:
                newShape = []
                new_shape_inferred = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                inferred_dim = testGen.rng.integers(1, newRank + 1)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    if i == inferred_dim:
                        new_shape_inferred.append(-1)
                    else:
                        new_shape_inferred.append(shuffledFactors[0])
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)
                if inferred_dim == newRank:
                    new_shape_inferred.append(-1)
                else:
                    new_shape_inferred.append(remainingElements)

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                if error_name in [
                    ErrorIf.ReshapeOutputSizeNonInteger,
                    ErrorIf.ReshapeOutputSizeMultiInference,
                ]:
                    if newRank < 2:
                        # Need at least two dimensions
                        continue
                    # NOTE: Change inferred_dim starting offset from 1 to 0
                    inferred_dim -= 1
                    extra_dim = inferred_dim + testGen.rng.integers(1, newRank)
                    extra_dim = extra_dim % newRank
                    assert extra_dim != inferred_dim
                    if error_name == ErrorIf.ReshapeOutputSizeNonInteger:
                        elements = 1
                        for i, dim_value in enumerate(new_shape_inferred):
                            if i != inferred_dim and i != extra_dim:
                                elements *= dim_value
                        dim_value = new_shape_inferred[extra_dim]
                        while totalElements % (elements * dim_value) == 0:
                            dim_value += 1
                        new_shape_inferred[extra_dim] = dim_value
                    else:
                        assert error_name == ErrorIf.ReshapeOutputSizeMultiInference
                        new_shape_inferred[extra_dim] = -1
                else:
                    arg_list.append(
                        ("perm{}_rank{}_outdefined".format(p, newRank), [newShape])
                    )
                if error_name != ErrorIf.TensorSizeInputOutputMismatch:
                    arg_list.append(
                        (
                            "perm{}_rank{}_outinferred".format(p, newRank),
                            [new_shape_inferred],
                        )
                    )

        return arg_list

    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]

        if error_name == ErrorIf.IndexOutsideBounds:
            incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
            incorrect_small_index = range(-len(ifm_shape), 0)
            permutations = [p for p in itertools.permutations(incorrect_large_index)]
            permutations.extend(
                [p for p in itertools.permutations(incorrect_small_index)]
            )
        elif error_name == ErrorIf.IndexUsedTwice:
            # Create list with a duplicated index
            perm_range = list(range(len(ifm_shape)))
            index_choice = testGen.rng.choice(range(len(perm_range)))
            perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
            permutations = [p for p in itertools.permutations(perm_range)]
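            # e.g. for rank 3 with index_choice 1, perm_range becomes
            # [0, 1, 1], so every permutation repeats an axis index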

        else:
            # Get all permutations
            permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Get random permutation generator that uses all permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create list of required amount of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list

    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            start = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    start.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - start[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    start.append(0)
                    size.append(1)

            if valid:
                # If ERROR_IF test required then incorrect start, size will be returned
                start, size = TosaErrorIfArgGen.eiSliceErrorIf(
                    testGen, error_name, ifm_shape, start, size
                )
                arg_list.append(("perm{}".format(p), [start, size]))
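                # e.g. for ifm_shape [1, 8, 8, 4] a valid pick could be
                # start=[0, 2, 0, 1] and size=[1, 3, 5, 2] (values assumed)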
        return arg_list

    @staticmethod
    def agTile(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random, but small multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                if ifm_shape[i] > 1000:
                    # Multiple of 1 if ifm_shape dimension is large to reduce
                    # tensor size
                    multiples.append(1)
                elif max(ifm_shape) > 1000:
                    multiples.append(2)
                else:
                    multiples.append(testGen.randInt(1, 4))
            arg_list.append(("perm{}".format(p), [multiples]))
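            # e.g. shape [2, 1200, 3] could yield multiples [2, 1, 2]: the
            # large dimension gets 1 and the others are capped at 2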

        return arg_list

    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []
        ifm_shape = shapeList[0]

        def get_aspect_ratio_resize_params():
            common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
            aspect_ratio = testGen.rng.choice(common_aspect_ratios)
            invert = testGen.rng.choice((False, True))
            letterbox = testGen.rng.choice((False, True))

            scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
            scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
            scale_y_d = scale_x_d = 1
            offset_x = offset_y = 0

            if letterbox:
                max_border = scale_y_n
                border_y = testGen.randInt(low=0, high=max_border)
                border_x = 0
            else:
                # Pillarboxing
                border_y = 0
                max_border = scale_x_n
                border_x = testGen.randInt(low=0, high=max_border)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)

            return scale, offset, border

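        # e.g. aspect ratio (4, 3) without inversion gives scale (3, 1, 4, 1),
        # a 3x upscale in Y and 4x in X, with any letterbox border in Y only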
        def get_upscale_downscale_params():
            valid_params = False
            while not valid_params:
                upscale = testGen.rng.choice((False, True))

                # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
                origin_sampling = testGen.rng.choice((False, True))

                if upscale:
                    shift = testGen.randInt(low=1, high=4)
                    scale_x_d = scale_y_d = 1
                    scale_x_n = scale_y_n = (
                        1 << shift if origin_sampling else 2 << shift
                    )
                    border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                    offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
                else:
                    scale_x_n = 1
                    scale_y_n = 1

                    # Return list of valid scale_*_d values (max value 4) given input dim shape
                    def get_valid_denom(ifm_dim):
                        return [x for x in range(1, 5) if ifm_dim % x == 1]

                    # Generate list of valid downscale values and choose one randomly
                    valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                    valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                    if not valid_scale_y_ds and not valid_scale_x_ds:
                        # Bad parameters, skip
                        continue

                    if not valid_scale_y_ds:
                        scale_y_d = 1
                    else:
                        scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                    if not valid_scale_x_ds:
                        scale_x_d = 1
                    else:
                        scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                    border_x = border_y = 0
                    offset_y = testGen.randInt(0, 16 * scale_y_n)
                    offset_x = testGen.randInt(0, 16 * scale_x_n)
                valid_params = True

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_rand_params():
            def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
                scale = scale_n / scale_d
                if scale > max_scale:
                    factor = scale / max_scale
                    new_scale_d = math.ceil(scale_d * factor)
                    assert scale_n / new_scale_d <= max_scale
                    scale_d = new_scale_d
                return scale_d

            # Scale
            scale_y_n = testGen.randInt(low=1, high=(1 << 11))
            scale_x_n = testGen.randInt(low=1, high=(1 << 11))

            scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

            scale_y_d = fix_scale_to_max_scale(
                scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_x_d = fix_scale_to_max_scale(
                scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )

            # Offsets and border within the scale
            offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
            border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_level_8k_params():
            # Create 64x scale - 64/1 to 2048/32
            scale_d = testGen.randInt(
                low=1, high=(1 << 11) // testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
            # Create half to fifth scaling
            scale_d_alt = testGen.randInt(low=2, high=6)
            scale_n_alt = 1
            switch = testGen.rng.choice((False, True))
            if switch:
                scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
            else:
                scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)

            offset_y = testGen.rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
            offset_x = testGen.rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
            offset = (offset_y, offset_x)
            border_y = testGen.rng.choice((-16 * scale[0], 0, scale[0] - 1))
            border_x = testGen.rng.choice((-16 * scale[2], 0, scale[2] - 1))
            border = (border_y, border_x)
            return scale, offset, border

        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations. Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.BF16:
                outputDTypeList = [DType.BF16]
            elif dtype == DType.FP32:
                outputDTypeList = [DType.FP32]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                perm = 0
                while perm < testGen.args.num_rand_permutations:
                    # Random choice of type of params we are testing
                    if not testGen.args.level8k:
                        _rnd_param_fn = testGen.rng.choice(
                            (
                                get_rand_params,
                                get_upscale_downscale_params,
                                get_aspect_ratio_resize_params,
                            )
                        )
                        scale, offset, border = _rnd_param_fn()
                    else:
                        scale, offset, border = get_level_8k_params()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        # Look for non-integer test
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            if perm > 0:
                                perm += 1
                            continue
                    else:
                        # Alter the scaling factors to make the output integer
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1
                        # Make sure we are still within max scaling
                        if (
                            scale_y_n / scale_y_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
                            scale_x_n / scale_x_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
                            # Skip the test as it is using too large a scaling factor
                            if perm > 0:
                                perm += 1
                            continue

                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1
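                    # Worked example (values assumed): ifm H=8 with scale 2/1,
                    # offset 0 and border 0 gives partial_output_y = (8-1)*2
                    # = 14 and output_y = 14 // 1 + 1 = 15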

                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        if not testGen.args.level8k or perm > 0:
                            perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= gtu.MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= gtu.MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]

                    # Common for all data types
                    if error_name is not None:
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        [
                            mode,
                            scale,
                            offset,
                            border,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)
        return arg_list

    @staticmethod
    def agTable(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        if dtype == DType.INT8:
            table = np.int32(
                testGen.rng.integers(low=-128, high=128, size=[256])
            ).tolist()
        else:  # INT16
            table = np.int32(
                testGen.rng.integers(low=-32768, high=32768, size=[513])
            ).tolist()
            # Make sure all slopes are within REQUIRE min/max 16-bit int
            for idx in range(len(table) - 1):
                slope = table[idx + 1] - table[idx]
                # Alter the next table entry to force the slope to be ok
                if slope > 32767:
                    table[idx + 1] -= slope - 32767
                if slope < -32768:
                    table[idx + 1] -= slope + 32768
                slope = table[idx + 1] - table[idx]
                assert slope <= 32767 and slope >= -32768
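                # e.g. entries 32000 then -32000 give slope -64000; the
                # second entry is adjusted by +31232 to -768 so the slope
                # becomes exactly -32768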
        arg_list.append(
            (
                "",
                [table],
            )
        )
        return arg_list

    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
        # While loop: 0 iterations, 1, more than 1
        arg_list = []

        for iter in [0, 1, 4]:
            arg_list.append(("iter{}".format(iter), [iter]))

        return arg_list