# Copyright (c) 2021-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
import math
import warnings

import numpy as np
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_utils import get_accum_dtype_from_tgTypes
from generator.tosa_utils import get_wrong_output_type
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode

# DTypeNames, DType, Op and ResizeMode are convenience variables to the
# flatc-generated types that should be enums, but aren't


class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(testGen, dtype, error_name=None):

        if dtype == DType.INT8:
            if testGen.args.zeropoint is not None:
                return min(127, max(-128, testGen.args.zeropoint))
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            if testGen.args.zeropoint is not None:
                return min(255, max(0, testGen.args.zeropoint))
            return testGen.randInt(0, 256)
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            zero_point = testGen.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift

        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15

        m, shift = math.frexp(scaleFp)

        if scaleFp < 0.0:
            m = -m

        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)

        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1

        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
        #     scaleFp, scaleBits, m, multiplier, shift))

        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift
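    # Illustrative worked example (assuming scale32=True): for scaleFp=0.25,
    # math.frexp(0.25) returns (0.5, -1), so multiplier = round(0.5 * 2**31)
    # = 2**30 and shift = -(-1) + 31 = 32; the result (1073741824, 32)
    # satisfies multiplier / 2**shift == 0.25 with shift inside [2, 62].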


class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            shape[0] = (shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0
        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)

        # ignore max batch size if target shape is set
        if testGen.args.max_batch_size and not testGen.args.target_shapes:
            values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1

        W = testGen.randInt(
            testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
        )
        # Constrict W if one dimension is too large to keep tensor size reasonable
        if max(values_in_shape) > 5000:
            W = testGen.randInt(0, 16)

        input_shape = [values_in_shape[0], W, values_in_shape[2]]

        shape_list = []
        shape_list.append(values_in_shape.copy())
        shape_list.append(input_shape.copy())

        return shape_list

    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        shape = testGen.makeShape(rank)

        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        for i in range(pl + const):
            shape_bcast = shape.copy()

            # If this is the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                fuzz_idx = testGen.randInt(0, rank)
                if error_name == ErrorIf.DimensionMismatch:
                    shape_bcast[fuzz_idx] += 1
                elif error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeShape(1)[0]

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeShape(1)[0]

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)

        # Constrict the batch size?
        if testGen.args.max_batch_size:
            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
        filter_hw = op["filter"]

        # Generate a random OFM depth, but don't let it get too big because
        # the output depth is M * C
        filter_m = (
            testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
        ) + 1

        # The filter dimensions are HWCM
        filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])

        # The bias is M * C
        bias_shape = np.asarray([ifm_shape[3] * filter_m])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList
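    # Illustrative example of the splitting above: given shapeList
    # [[2, 6, 3]] * 4 and axis=1, this returns
    # [[2, 6, 3], [2, 3, 3], [2, 1, 3], [2, 2, 3]], i.e. the original shape
    # followed by splits whose axis lengths (3 + 1 + 2) sum back to 6.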


class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each test."""

    def __init__(self):
        pass

    @staticmethod
    def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        pCount, cCount = op["operands"]

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        return tens

    @staticmethod
    def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32 and error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.NEGATE must have 1 placeholder, 0 consts"
            # Must create tensors with values within accumulator (int32) negatable
            # range
            max_val = (1 << 31) - 1
            min_val = -max_val
            arr = np.int32(
                testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Make sure the operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            placeholders = []
            add = op["op"] == Op.ADD
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap and negate the saturation values as we need to perform
                # the opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
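    # Illustrative example of the clipping above (hypothetical values): for
    # ADD with a_arr = [2147483000] and b_arr = [1000], res_arr = 2147484000,
    # so sat_max_arr = [353] and b_unsat_arr = [647]; the adjusted sum
    # 2147483000 + 647 = 2147483647 reaches INT32 max without wrapping.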

    @staticmethod
    def tvgCondIfWhileLoop(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        if dtypeList[0] in (
            DType.INT32,
            DType.INT16,
            DType.INT8,
        ):
            # Limit input tensors with cond_if_binary or while_loop to stop
            # saturation of add/sub ops with int32 and keep all logical shift
            # values between 0 and 31 for int16 or int8
            pCount, cCount = op["operands"]
            pRemain = pCount
            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[0] == DType.INT32:
                    arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
                else:
                    arr = np.int32(
                        testGen.rng.integers(low=0, high=32, size=shapeList[idx])
                    )
                if pRemain > 0:
                    placeholders.append(
                        testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
                    )
                    pRemain -= 1
                else:
                    placeholders.append(
                        testGen.ser.addConst(shape, dtypeList[idx], arr)
                    )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgArithmeticRightShift(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        pCount, cCount = op["operands"]
        # Force value of operand[1] to be within [0, num_bits]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

        placeholders = []
        for idx, shape in enumerate(shapeList[:]):
            if idx == 1:
                if dtypeList[idx] == DType.INT8:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                elif dtypeList[idx] == DType.INT16:
                    arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
                elif dtypeList[idx] == DType.INT32:
                    arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
                elif error_name == ErrorIf.WrongInputType:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                else:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
            else:
                arr = testGen.getRandTensor(shape, dtypeList[idx])
            placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))

        return placeholders

    @staticmethod
    def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        # Set datatype of condition tensor to boolean
        dtypeList[0] = DType.BOOL

        return TosaTensorValuesGen.tvgDefault(
            testGen, op, dtypeList, shapeList, testArgs, error_name
        )

    @staticmethod
    def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            placeholders = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
            while True:
                dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
                    continue

                break

            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            tens = []
            if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32):
                tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
            else:
                placeholders = []

                # Make sure the multiply result is in int32 range
                shift = testArgs[0]
                if dtypeList[0] == DType.INT8:
                    num_bits = 8
                elif dtypeList[0] == DType.INT16:
                    num_bits = 16
                elif dtypeList[0] == DType.INT32:
                    num_bits = 32
                elif error_name == ErrorIf.WrongInputType:
                    num_bits = 8
                else:
                    raise Exception("OpMul: invalid input dtype")

                for idx, shape in enumerate(shapeList[:]):
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                    a_arr = np.int32(
                        testGen.rng.integers(low=low, high=high, size=shapeList[0])
                    )
                    b_arr = np.int32(
                        testGen.rng.integers(low=low, high=high, size=shapeList[1])
                    )

                i = 0
                while True:

                    a_arr_64 = a_arr.astype(np.int64)
                    b_arr_64 = b_arr.astype(np.int64)

                    if shift > 0:
                        rounding = 1 << (shift - 1)
                        result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                    else:
                        result_arr = a_arr_64 * b_arr_64

                    if (result_arr > -(2**31)).all() and (
                        result_arr <= ((2**31) - 1)
                    ).all():
                        break

                    i = i + 1
                    a_arr = a_arr // 2
                    b_arr = b_arr // 2

                placeholders.append(
                    testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                placeholders.append(
                    testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

                tens.extend(placeholders)

            return tens
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
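    # Illustrative example of the range-reduction loop (hypothetical values):
    # with shift=0 and a_arr = b_arr = [65536], the 64-bit product 2**32
    # exceeds INT32 max, so both operands are halved to 32768 and the retried
    # product 2**30 fits, ending the while loop.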

    @staticmethod
    def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        count = len(shapeList) - testGen.args.num_const_inputs_concat
        if count < 1:
            count = 1
        if testGen.args.num_const_inputs_concat == 0:
            count = len(shapeList)

        # Ensure axis is an int
        testArgs[0] = int(testArgs[0])

        shapeList = TosaTensorGen.tgConcatConstInput(
            testGen, shapeList, testArgs[0], error_name
        )

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
        )
        tens.extend(testGen.buildConstTensors(shapeList[count:], dtypeList[count:]))

        return tens

    @staticmethod
    def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        pCount, cCount = op["operands"]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
        values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
        shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
        placeholders = []
        placeholders.append(
            testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
        )
        placeholders.append(
            testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
        )

        return placeholders

    @staticmethod
    def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if error_name is None:
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force twice as
            # many matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        testGen.rng.integers(
                            0, max(shapeList[0][axis], shapeList[1][axis])
                        )
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        if dtypeList[0] == DType.INT32:
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.REDUCE_SUM must have 1 placeholder, 0 consts"
            # Limit values so that the sum cannot exceed the range of an int32 during
            # summation of any axis
            range_val = int((1 << 31) / max(shapeList[0]))
            values_arr = np.int32(
                testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
            )
            placeholders = []
            placeholders.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
            )
            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )
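    # Illustrative example of the value limit above: for shapeList[0] = [4, 8],
    # range_val = int((1 << 31) / 8) = 268435456, so even summing eight values
    # of maximum magnitude stays within the int32 accumulator range.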


class TosaArgGen:
    """Argument generators create exhaustive or random lists of attributes for
    operators that take attributes or other parameters.

    The return value is a list of (descriptive_name, [arglist]) tuples where
    the descriptive_name is appended to the test name and the arglist is expanded
    as arguments to the operator build function.
    """

    def __init__(self):
        pass

    @staticmethod
    def agNone(testGen, opName, shapeList, dtype, error_name=None):
        """A trivial argument generator for operators that don't take any
        non-tensor arguments"""
        return [("", [])]

    @staticmethod
    def agAxis(testGen, opName, shapeList, dtype, error_name=None):
        """Build the axis argument for operators that take a single axis"""
        axes = []
        shape = shapeList[0]

        if error_name == ErrorIf.AxisSmallerZero:
            small_axis = testGen.rng.integers(-5, 0)
            axes.append(("axis{}".format(small_axis), [small_axis]))
        elif error_name == ErrorIf.AxisLargerRank:
            large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
            axes.append(("axis{}".format(large_axis), [large_axis]))
        else:
            for a in range(0, len(shape)):
                axes.append(("axis{}".format(a), [a]))

        return axes

    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]
        # determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3])
        k = [int(x) for x in opName.split("_")[-1].split("x")]

        accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        # Check the rank
        rank = 5 if opName.startswith("conv3d") else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits batch and channels
        k_rank = rank - 2
        assert len(k) == k_rank

        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        if error_name == ErrorIf.PadSmallerZero:
            p_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
        paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used as a divisor when deriving the
            # output shape
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            # Stride must be greater than 1 to force a non-integer error
            startStride = 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
            s_vals = [x for x in range(startStride, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * k_rank))}
        if error_name == ErrorIf.DilationSmallerOne:
            d_vals = [testGen.rng.choice(range(-5, 1))]
        else:
            d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
        dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

        if not error_name and testGen.args.oversize:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    {x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))}
                )
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
            bigDilation = 7
            dilations.update(
                {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
            )

        # There are too many parameter combinations, so generate them sparsely,
        # very sparsely for negative tests
        sparsity_factor = 2 if error_name else 120
        sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a
                        # positive-sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k[1] - 1)
                        and (
                            k_rank < 3
                            or ((ifm_shape[3] - 1 + p[4] + p[5]) > d[2] * (k[2] - 1))
                        )
                    ):
                        remainders = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            remainders.append(
                                (
                                    ifm_shape[index + 1]
                                    - 1
                                    + p[pad_offset]
                                    + p[pad_offset + 1]
                                    - (k[index] - 1) * d[index]
                                )
                                % s[index]
                            )
                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        "".join([str(x) for x in s]),
                                        "".join([str(x) for x in p]),
                                        "".join([str(x) for x in d]),
                                    ),
                                    [accum_dtype, s, p, d],
                                )
                            )
                    n += 1

        return arg_list
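    # Illustrative example of the integer-exactness check in agConv: for an
    # IFM height of 13 with kernel 3, dilation 1, stride 2 and no padding,
    # the remainder (13 - 1 + 0 + 0 - (3 - 1) * 1) % 2 == 0, giving an exact
    # output height of (12 - 2) // 2 + 1 = 6.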

    @staticmethod
    def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):

        assert isinstance(dtypes, list) or isinstance(
            dtypes, tuple
        ), f"{dtypes} unexpected"
        input_dtype = dtypes[0]

        if error_name == ErrorIf.WrongOutputType:
            accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype)
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            accum_dtype = DType.INT32
        else:
            accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])]

    @staticmethod
    def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
        # Get valid accumulate type(s)
        if dtype == DType.INT8:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.INT16:
            accum_dtypes = [DType.INT48]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FP32]
        elif dtype == DType.BF16:
            accum_dtypes = [DType.FP32]
        elif dtype == DType.FP32:
            accum_dtypes = [DType.FP32]
        elif error_name is None:
            assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"

        if error_name == ErrorIf.WrongOutputType:
            # Get incorrect output dtype for ErrorIf case
            accum_dtypes = [get_wrong_output_type(opName, testGen.rng, dtype)]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            accum_dtypes = [DType.INT32]

        return [(f"acc{testGen.typeStr(a)}", [a]) for a in accum_dtypes]

    @staticmethod
    def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        filter_shape = shapeList[1]

        accum_dtype = get_accum_dtype_from_tgTypes(dtypes)

        # Must be rank 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == 4
            assert len(filter_shape) == 4

        # Generate comprehensive argument lists
        # - except for named errors, which use specific invalid value(s)
        smallest_padding_size = -min(filter_shape[1], filter_shape[2]) + 1
        if error_name == ErrorIf.PadLargerEqualKernel:
            max_filter_size = -max(filter_shape[1], filter_shape[2])
            p_vals = [testGen.rng.choice(range(max_filter_size - 10, max_filter_size))]
        else:
            p_vals = [
                x
                for x in range(smallest_padding_size, testGen.args.max_conv_padding + 1)
            ]
        paddings = {x for x in itertools.product(*([p_vals] * 4))}
        if error_name == ErrorIf.StrideSmallerOne:
            # Can't use stride=0, as it is used as a divisor when deriving the
            # output shape
            s_vals = [testGen.rng.choice(range(-5, 0))]
        else:
            s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
        strides = {x for x in itertools.product(*([s_vals] * 2))}

        if not error_name and testGen.args.oversize:
            # add some oversize argument values
            if max(ifm_shape) < 64:
                bigPadding = 9
                paddings.update(
                    {
                        x
                        for x in itertools.product(
                            *([[smallest_padding_size, bigPadding]] * 4)
                        )
                    }
                )
            bigStride = 8
            strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})

        # There are too many parameter combinations, so generate them sparsely,
        # very sparsely for negative tests
        sparsity_factor = 2 if error_name else 10
        sparsity = len(paddings) * len(strides) // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                if n % sparsity == 0:
                    # Determine the output shape
                    oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + filter_shape[1]
                    ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + filter_shape[2]
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]
                    arg_list.append(
                        (
                            "acc{}_st{}_pad{}_os{}".format(
                                testGen.typeStr(accum_dtype),
                                "".join([str(x) for x in s]),
                                "".join([str(x) for x in p]),
                                "x".join([str(x) for x in os]),
                            ),
                            [accum_dtype, s, p, os],
                        )
                    )
                n += 1

        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype in (DType.FP16, DType.FP32):
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            return []

        for paddings in shape_pad_values:
            paddings = list(paddings)
            args_valid = True

            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring we still test
                # negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                if all([p > -1 for p in paddings[i]]):
                    args_valid = False

            if args_valid:
                name = "pad"
                for r in range(rank):
                    before, after = paddings[r]
                    name = f"{name}{before}{after}"
                arg_list.append(
                    (name, [np.array(paddings), pad_const_int, pad_const_fp])
                )

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

        return arg_list
1284
1285 @staticmethod
1286 def agPooling(testGen, opName, shapeList, dtype, error_name=None):
1287 arg_list = []
1288
1289 shape = shapeList[0]
1290 if error_name != ErrorIf.WrongRank:
1291 assert len(shape) == 4
1292
1293 # Generate comprehensive argument lists
1294 p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
1295 paddings = {x for x in itertools.product(*([p_vals] * 4))}
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001296 # Stride must be greater than 1 to force non-integer error
1297 startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
1298 s_vals = [x for x in range(startStride, testGen.args.max_pooling_stride + 1)]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001299 strides = {x for x in itertools.product(*([s_vals] * 2))}
1300 k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)]
1301 kernels = {x for x in itertools.product(*([k_vals] * 2))}
1302
James Ward8b390432022-08-12 20:48:56 +01001303 if opName == "max_pool2d":
1304 accum_dtypes = [None] # max_pool has no accumulate dtype
1305 elif dtype == DType.INT8 or dtype == DType.INT16:
1306 accum_dtypes = [DType.INT32]
1307 elif dtype == DType.FP16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001308 accum_dtypes = [DType.FP16, DType.FP32]
James Ward24dbc422022-10-19 12:20:31 +01001309 elif dtype == DType.BF16 or dtype == DType.FP32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001310 accum_dtypes = [DType.FP32]
James Ward8b390432022-08-12 20:48:56 +01001311 elif error_name is None:
1312 assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
1313 else:
1314 # Set to something for the ErrorIf case which has
1315 # incorrect input data-type
1316 accum_dtypes = [DType.INT32]
1317
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001318 if testGen.args.oversize:
1319 # add some oversize argument values
1320 bigStride = 7
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001321 strides.update(
1322 {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
1323 )
1324 bigKernel = 9
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001325 kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
1326 if max(shape) < 64:
1327 # padding must be less than the kernel size
1328 bigPadding = bigKernel - 1
1329 paddings.update(
1330 {x for x in itertools.product(*([[0, bigPadding]] * 4))}
1331 )
1332
1333 # There are too many parameter combinations, so generate them sparsely,
1334 # very sparse for negative tests
1335 sparsity_factor = 2 if error_name else 500
1336 sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
1337
James Ward8b390432022-08-12 20:48:56 +01001338 arg_str = (
1339 "acc{}_st{}_kern{}_pad{}"
1340 if accum_dtypes[0] is not None
1341 else "st{}_kern{}_pad{}"
1342 )
1343
1344 def get_arg_list_element(accum, stride, pad, kern):
1345 # Return tuple containing the formatted argument string and
1346 # the corresponding argument values
1347 arg_str_elems = [
1348 "".join([str(x) for x in stride]),
1349 "".join([str(x) for x in kern]),
1350 "".join([str(x) for x in pad]),
1351 ]
1352 # Note: different order to string
1353 arg_val_elems = [stride, pad, kern]
1354
1355 if accum is not None:
1356 arg_str_elems.insert(0, testGen.typeStr(accum))
1357 arg_val_elems.insert(0, accum)
1358 return (arg_str.format(*arg_str_elems), arg_val_elems)
1359
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001360 n = 0
James Ward8b390432022-08-12 20:48:56 +01001361 for a in accum_dtypes:
1362 for s in sorted(list(strides)):
1363 for p in sorted(list(paddings)):
1364 for k in sorted(list(kernels)):
1365 if error_name in [
1366 ErrorIf.StrideSmallerOne,
1367 ErrorIf.KernelSmallerOne,
1368 ErrorIf.PadSmallerZero,
1369 ErrorIf.PadLargerEqualKernel,
1370 ]:
1371 sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
1372 testGen, error_name, s, p, k
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001373 )
James Ward8b390432022-08-12 20:48:56 +01001374 if None not in [sNew, pNew, kNew] and n % sparsity == 0:
1375 arg_vals = [a, sNew, pNew, kNew]
1376 arg_list.append(get_arg_list_element(*arg_vals))
1377 elif (
1378 n % sparsity == 0
1379 # padding must not exceed the kernel size
1380 and p[0] < k[0]
1381 and p[1] < k[0]
1382 and p[2] < k[1]
1383 and p[3] < k[1]
1384 # the padded shape must exceed the kernel size
1385 and (shape[1] + p[0] + p[1]) > k[0]
1386 and (shape[2] + p[2] + p[3]) > k[1]
Jeremy Johnson4a6fb9b2022-04-26 15:47:21 +01001387 ):
James Ward8b390432022-08-12 20:48:56 +01001388 remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0]
1389 remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1]
1390 if (
1391 # the parameters must produce integer exact output
1392 error_name != ErrorIf.PoolingOutputShapeNonInteger
1393 and remainder_h == 0
1394 and remainder_w == 0
1395 ) or (
1396 error_name == ErrorIf.PoolingOutputShapeNonInteger
1397 and (remainder_h != 0 or remainder_w != 0)
1398 ):
1399 arg_vals = [a, s, p, k]
1400 arg_list.append(get_arg_list_element(*arg_vals))
1401 n += 1
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001402
1403 return arg_list
1404
1405 @staticmethod
1406 def agCast(testGen, opName, shapeList, inDtype, error_name=None):
1407 arg_list = []
1408
1409 # Enumerate the output types here
1410 if error_name == ErrorIf.WrongOutputType:
1411 dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
1412 elif inDtype == DType.INT8:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001413 dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001414 elif inDtype == DType.INT16:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001415 dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001416 elif inDtype == DType.INT32:
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001417 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001418 elif inDtype == DType.BOOL:
1419 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward8b390432022-08-12 20:48:56 +01001420 elif inDtype == DType.FP16:
1421 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
James Ward24dbc422022-10-19 12:20:31 +01001422 elif inDtype == DType.BF16:
1423 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001424 elif inDtype == DType.FP32:
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001425 dtypeList = [DType.INT8, DType.INT16, DType.INT32]
1426 elif error_name == ErrorIf.WrongInputType:
1427 # Pick some potentially correct output type for incorrect input type
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +01001428 dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001429 else:
1430 raise Exception("Unexpected input dtype: {}".format(inDtype))
1431
1432 for dtype in dtypeList:
Jeremy Johnson3b0544c2022-10-18 16:32:19 +01001433 arg_list.append(("out{}".format(testGen.typeStr(dtype)), [dtype]))
Jeremy Johnson9a66abb2022-04-07 11:29:20 +01001434
1435 return arg_list
1436
1437 @staticmethod
1438 def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # These ErrorIfs are only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only valid output dtypes for input UINT8 are INT8/INT16,
                # skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only valid input dtypes for output UINT8 are INT8/INT16,
                # skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only valid output dtype for input UINT16 is INT16,
                # skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only valid input dtype for output UINT16 is INT16,
                # skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:
                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition.  Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition.  ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [outDtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

    @staticmethod
    def agMul(testGen, opName, shapeList, dtype, error_name=None):
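        """Generate the shift argument for MUL.

        Only INT32 inputs use a non-zero result shift: a random shift in
        [0, 32) is chosen per permutation.  All other dtypes get shift 0.
        """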
        arg_list = []

        if dtype == DType.INT32:
            for p in range(testGen.args.num_rand_permutations):
                shift = testGen.randInt(0, 32)
                arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
        else:
            arg_list.append(("perm0_shift0", [0]))

        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
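        """Generate the 'round' argument: one test case each for
        round=True and round=False."""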
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    @staticmethod
    def getFactors(val, start=1):
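        """Helper for reshape: return the factors of val from start up to
        sqrt(val).

        Illustrative example:
            getFactors(12) -> [1, 2, 3]
        """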
        factors = []

        for i in range(start, int(np.sqrt(val)) + 1):
            if (val % i) == 0:
                factors.append(i)

        return factors

    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
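        """Generate random new shapes for RESHAPE.

        Candidate shapes are built from random factor permutations of the
        total element count, so every candidate preserves the number of
        elements; duplicate shapes are filtered out below.
        """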
        arg_list = []

        origShape = shapeList[0]

        totalElements = 1
        for s in origShape:
            totalElements *= s

        # This code is NOT fast.  Fortunately, the numbers are fairly small.
        factors = TosaArgGen.getFactors(totalElements)

        for p in range(testGen.args.num_rand_permutations):
            newRank = testGen.randInt(1, 7)
            if len(factors) < newRank:
                continue

            found = True
            # escape_counter breaks the while loop if it runs for too long
            escape_counter = 0
            while found:
                newShape = []
                # Generate newShape ensuring it isn't a duplicate
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank - 1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)

                # Check for duplicates
                found = False
                for name, other_shape in arg_list:
                    if other_shape[0] == newShape:
                        found = True
                        break

                escape_counter += 1
                if escape_counter >= 100:
                    break

            if not found:
                arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))

        return arg_list

    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
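        """Generate dimension permutations for TRANSPOSE.

        For ERROR_IF tests the indices are made out-of-bounds or
        duplicated; otherwise up to num_rand_permutations of the valid
        permutations are picked at random.
        """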
        arg_list = []

        ifm_shape = shapeList[0]

        if error_name == ErrorIf.IndexOutsideBounds:
            incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
            incorrect_small_index = range(-len(ifm_shape), 0)
            permutations = [p for p in itertools.permutations(incorrect_large_index)]
            permutations.extend(
                [p for p in itertools.permutations(incorrect_small_index)]
            )
        elif error_name == ErrorIf.IndexUsedTwice:
            # Create list with a duplicated index
            perm_range = list(range(len(ifm_shape)))
            index_choice = testGen.rng.choice(range(len(perm_range)))
            perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
            permutations = [p for p in itertools.permutations(perm_range)]
        else:
            # Get all permutations
            permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Get random permutation generator that uses all permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create list of required number of permutations
        arg_list = [
            ("perm{}".format(p), [random_permutations[p].tolist()])
            for p in range(limit)
        ]
        return arg_list

    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype, error_name=None):
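        """Generate random start/size arguments for SLICE.

        Each dimension larger than 1 gets a random start and size within
        the input shape; for ERROR_IF tests the values are then replaced
        with deliberately incorrect ones.
        """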
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            start = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    start.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - start[i]))

                    # Invalid slice size?
                    if size[i] == 0:
                        valid = False
                else:
                    start.append(0)
                    size.append(1)

            if valid:
                # If an ERROR_IF test is required then incorrect start/size
                # values will be returned
                start, size = TosaErrorIfArgGen.eiSliceErrorIf(
                    testGen, error_name, ifm_shape, start, size
                )
                arg_list.append(("perm{}".format(p), [start, size]))
        return arg_list

    @staticmethod
    def agTile(testGen, opName, shapeList, dtype, error_name=None):
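        """Generate random multiples for TILE, kept deliberately small so
        the tiled output tensor does not become enormous."""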
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):

            # Pick a few random, but small multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                if ifm_shape[i] > 1000:
                    # Multiple of 1 if ifm_shape dimension is large to reduce
                    # tensor size
                    multiples.append(1)
                elif max(ifm_shape) > 1000:
                    multiples.append(2)
                else:
                    multiples.append(testGen.randInt(1, 4))
            arg_list.append(("perm{}".format(p), [multiples]))

        return arg_list

    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
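        """Generate scale, offset and border permutations for RESIZE.

        Each permutation uses one of three randomly chosen strategies:
        fully random parameters, exact up/downscaling, or a common
        aspect-ratio (letterbox/pillarbox) resize.  The output size
        calculation used below is, per axis:

            OH = ((IH - 1) * scale_y_n - offset_y + border_y) // scale_y_d + 1

        Illustrative example: IH=16, scale=2/1, offset=0, border=0 gives
        OH = (15 * 2) // 1 + 1 = 31.
        """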
        arg_list = []
        ifm_shape = shapeList[0]

        def get_aspect_ratio_resize_params():
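            """Generate a resize with a common aspect ratio, emulating
            letterboxing (extra border on Y) or pillarboxing (extra
            border on X)."""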
            common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
            aspect_ratio = testGen.rng.choice(common_aspect_ratios)
            invert = testGen.rng.choice((False, True))
            letterbox = testGen.rng.choice((False, True))

            scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
            scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
            scale_y_d = scale_x_d = 1
            offset_x = offset_y = 0

            if letterbox:
                max_border = scale_y_n
                border_y = testGen.randInt(low=0, high=max_border)
                border_x = 0
            else:
                # Pillarboxing
                border_y = 0
                max_border = scale_x_n
                border_x = testGen.randInt(low=0, high=max_border)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)

            return scale, offset, border

        def get_upscale_downscale_params():
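            """Generate either a power-of-two upscale or an integer
            downscale, retrying until the parameters are valid for the
            input shape."""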
            valid_params = False
            while not valid_params:
                upscale = testGen.rng.choice((False, True))

                # True if sampling begins from (0, 0).  Otherwise (-0.5, -0.5)
                origin_sampling = testGen.rng.choice((False, True))

                if upscale:
                    shift = testGen.randInt(low=1, high=4)
                    scale_x_d = scale_y_d = 1
                    scale_x_n = scale_y_n = (
                        1 << shift if origin_sampling else 2 << shift
                    )
                    border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                    offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
                else:
                    scale_x_n = 1
                    scale_y_n = 1

                    # Return list of valid scale_*_d values (max value 4) given
                    # the input dim shape
                    def get_valid_denom(ifm_dim):
                        return [x for x in range(1, 5) if ifm_dim % x == 1]
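                    # For x > 1, "ifm_dim % x == 1" is equivalent to
                    # "(ifm_dim - 1) % x == 0", i.e. denominators that divide
                    # (ifm_dim - 1) exactly; x == 1 never matches, so the
                    # fallbacks below use a denominator of 1 when no other
                    # value is valid.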

                    # Generate list of valid downscale values and choose one randomly
                    valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                    valid_scale_x_ds = get_valid_denom(ifm_shape[2])

                    if not valid_scale_y_ds and not valid_scale_x_ds:
                        # Bad parameters, skip
                        continue

                    if not valid_scale_y_ds:
                        scale_y_d = 1
                    else:
                        scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                    if not valid_scale_x_ds:
                        scale_x_d = 1
                    else:
                        scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                    border_x = border_y = 0
                    offset_y = testGen.randInt(0, 16 * scale_y_n)
                    offset_x = testGen.randInt(0, 16 * scale_x_n)
                    valid_params = True

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_rand_params():
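            """Generate unconstrained random scale, offset and border
            values; the main loop below adjusts or skips any that give
            invalid output dimensions."""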
            # Scale
            scale_y_n = testGen.randInt(low=1, high=(1 << 11))
            scale_x_n = testGen.randInt(low=1, high=(1 << 11))

            scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

            # Offsets and border within the scale
            offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))
            border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations.  Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.BF16:
                outputDTypeList = [DType.BF16]
            elif dtype == DType.FP32:
                outputDTypeList = [DType.FP32]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                perm = 0
                while perm < testGen.args.num_rand_permutations:
                    # Random choice of type of params we are testing
                    _rnd_param_fn = testGen.rng.choice(
                        (
                            get_rand_params,
                            get_upscale_downscale_params,
                            get_aspect_ratio_resize_params,
                        )
                    )
                    scale, offset, border = _rnd_param_fn()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            perm += 1
                            continue
                    else:
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1

                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1

                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create an invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]

                    # Common for all data types
                    if error_name is not None:
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        [
                            mode,
                            scale,
                            offset,
                            border,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)
        return arg_list

    @staticmethod
    def agTable(testGen, opName, shapeList, dtype, error_name=None):
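        """Generate a random lookup table for TABLE.

        INT8 uses a 256-entry table; INT16 uses a 513-entry table whose
        adjacent-entry slopes must fit in 16 bits, so the entries are
        clamped below to keep every slope within [-32768, 32767].
        """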
        arg_list = []

        if dtype == DType.INT8:
            table = np.int32(
                testGen.rng.integers(low=-128, high=128, size=[256])
            ).tolist()
        else:  # INT16
            table = np.int32(
                testGen.rng.integers(low=-32768, high=32768, size=[513])
            ).tolist()
            # Make sure all slopes are within the REQUIRE min/max 16-bit int range
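            # Illustrative example: adjacent entries [-30000, 10000] have
            # slope 40000, so the second entry is reduced by
            # 40000 - 32767 = 7233 to 2767, making the slope exactly 32767.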
            for idx in range(len(table) - 1):
                slope = table[idx + 1] - table[idx]
                # Alter the next table entry to force the slope to be ok
                if slope > 32767:
                    table[idx + 1] -= slope - 32767
                if slope < -32768:
                    table[idx + 1] -= slope + 32768
                slope = table[idx + 1] - table[idx]
                assert slope <= 32767 and slope >= -32768
        arg_list.append(
            (
                "",
                [table],
            )
        )
        return arg_list

    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
        # While loop: 0 iterations, 1, or more than 1
        arg_list = []

        for iterations in [0, 1, 4]:
            arg_list.append(("iter{}".format(iterations), [iterations]))

        return arg_list